Generalizing sloppy_interleave, making sloppiness an option.
PiperOrigin-RevId: 173687797
This commit is contained in:
parent
7775a66043
commit
6b05b36cd2
@ -50,6 +50,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element
|
||||
from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
|
||||
from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
|
||||
from tensorflow.contrib.data.python.ops.grouping import group_by_window
|
||||
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
|
||||
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
|
||||
from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset
|
||||
from tensorflow.contrib.data.python.ops.readers import read_batch_features
|
||||
@ -57,7 +58,6 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset
|
||||
from tensorflow.contrib.data.python.ops.readers import TextLineDataset
|
||||
from tensorflow.contrib.data.python.ops.readers import TFRecordDataset
|
||||
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
|
||||
from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave
|
||||
from tensorflow.python.data.ops.iterator_ops import Iterator
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
@ -143,6 +143,29 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "interleave_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["interleave_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual", # b/67958761
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "iterator_ops_cluster_test",
|
||||
size = "small",
|
||||
@ -352,29 +375,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "sloppy_transformation_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["sloppy_transformation_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual", # b/67958761
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "sql_dataset_op_test",
|
||||
size = "small",
|
||||
|
@ -25,7 +25,7 @@ import time
|
||||
from six.moves import zip_longest
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.contrib.data.python.ops import sloppy_ops
|
||||
from tensorflow.contrib.data.python.ops import interleave_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -34,12 +34,13 @@ from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
class ParallelInterleaveDatasetTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
|
||||
self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
|
||||
|
||||
self.repeat_count = 2
|
||||
|
||||
@ -69,9 +70,9 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
|
||||
self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
|
||||
.repeat(self.repeat_count).apply(
|
||||
sloppy_ops.sloppy_interleave(
|
||||
interleave_ops.parallel_interleave(
|
||||
interleave_fn, self.cycle_length,
|
||||
self.block_length)))
|
||||
self.block_length, self.sloppy)))
|
||||
self.iterator = self.dataset.make_initializable_iterator()
|
||||
self.init_op = self.iterator.initializer
|
||||
self.next_element = self.iterator.get_next()
|
||||
@ -161,7 +162,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
for i in range(4, 7):
|
||||
self.write_coordination_events[i].set()
|
||||
|
||||
def testSingleThreaded(self):
|
||||
def _testSingleThreaded(self, sloppy=False):
|
||||
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
|
||||
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
|
||||
with self.test_session() as sess:
|
||||
@ -171,7 +172,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 1,
|
||||
self.block_length: 1
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy
|
||||
})
|
||||
|
||||
for expected_element in self._interleave(
|
||||
@ -182,7 +184,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContention(self):
|
||||
def testSingleThreaded(self):
|
||||
self._testSingleThreaded()
|
||||
|
||||
def testSingleThreadedSloppy(self):
|
||||
self._testSingleThreaded(sloppy=True)
|
||||
|
||||
def _testTwoThreadsNoContention(self, sloppy=False):
|
||||
# num_threads > 1.
|
||||
# Explicit coordination should result in `Dataset.interleave()` behavior
|
||||
with self.test_session() as sess:
|
||||
@ -193,7 +201,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
@ -211,43 +220,59 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContention(self):
|
||||
self._testTwoThreadsNoContention()
|
||||
|
||||
def testTwoThreadsNoContentionSloppy(self):
|
||||
self._testTwoThreadsNoContention(sloppy=True)
|
||||
|
||||
def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
|
||||
"""Tests where all the workers race in producing elements.
|
||||
|
||||
Note: this is in contrast with the previous test which carefully sequences
|
||||
the execution of the map functions.
|
||||
|
||||
Args:
|
||||
sloppy: Whether to be sloppy or not.
|
||||
"""
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
done_first_event = False
|
||||
sess.run(
|
||||
self.init_op,
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
1)):
|
||||
if done_first_event: # First event starts the worker threads.
|
||||
self._allow_all_map_threads()
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
|
||||
actual_element = sess.run(self.next_element)
|
||||
if not done_first_event:
|
||||
done_first_event = True
|
||||
self.assertTrue(
|
||||
self.read_coordination_events[expected_element].acquire(False))
|
||||
self.assertEqual(expected_element * expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContentionWithRaces(self):
|
||||
"""Tests where all the workers race in producing elements.
|
||||
self._testTwoThreadsNoContentionWithRaces()
|
||||
|
||||
Note: this is in contrast with the previous test which carefully sequences
|
||||
the execution of the map functions.
|
||||
"""
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
done_first_event = False
|
||||
sess.run(
|
||||
self.init_op,
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
1)):
|
||||
if done_first_event: # First event starts the worker threads.
|
||||
self._allow_all_map_threads()
|
||||
self.read_coordination_events[expected_element].acquire()
|
||||
else:
|
||||
self.write_coordination_events[expected_element].set()
|
||||
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
|
||||
actual_element = sess.run(self.next_element)
|
||||
if not done_first_event:
|
||||
done_first_event = True
|
||||
self.assertTrue(
|
||||
self.read_coordination_events[expected_element].acquire(False))
|
||||
self.assertEqual(expected_element * expected_element, actual_element,
|
||||
"At index %s: %s expected, got: %s" %
|
||||
(i, expected_element, actual_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
def testTwoThreadsNoContentionWithRacesSloppy(self):
|
||||
self._testTwoThreadsNoContentionWithRaces(sloppy=True)
|
||||
|
||||
def testTwoThreadsNoContentionBlockLength(self):
|
||||
def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
|
||||
# num_threads > 1.
|
||||
# Explicit coordination should result in `Dataset.interleave()` behavior
|
||||
with self.test_session() as sess:
|
||||
@ -258,7 +283,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 2
|
||||
self.block_length: 2,
|
||||
self.sloppy: sloppy
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
@ -276,11 +302,21 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
|
||||
def testTwoThreadsNoContentionBlockLength(self):
|
||||
self._testTwoThreadsNoContentionBlockLength()
|
||||
|
||||
def testTwoThreadsNoContentionBlockLengthSloppy(self):
|
||||
self._testTwoThreadsNoContentionBlockLength(sloppy=True)
|
||||
|
||||
def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
|
||||
"""Tests where all the workers race in producing elements.
|
||||
|
||||
Note: this is in contrast with the previous test which carefully sequences
|
||||
the execution of the map functions.
|
||||
|
||||
|
||||
Args:
|
||||
sloppy: Whether to be sloppy or not.
|
||||
"""
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
@ -290,7 +326,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 2
|
||||
self.block_length: 2,
|
||||
self.sloppy: sloppy
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
|
||||
@ -312,7 +349,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testEmptyInput(self):
|
||||
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
|
||||
self._testTwoThreadsNoContentionWithRacesAndBlocking()
|
||||
|
||||
def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
|
||||
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
|
||||
|
||||
def _testEmptyInput(self, sloppy=False):
|
||||
with self.test_session() as sess:
|
||||
# Empty input.
|
||||
self._clear_coordination_events()
|
||||
@ -321,12 +364,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 3
|
||||
self.block_length: 3,
|
||||
self.sloppy: sloppy
|
||||
})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testNonEmptyInputIntoEmptyOutputs(self):
|
||||
def testEmptyInput(self):
|
||||
self._testEmptyInput()
|
||||
|
||||
def testEmptyInputSloppy(self):
|
||||
self._testEmptyInput(sloppy=True)
|
||||
|
||||
def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
|
||||
# Non-empty input leading to empty output.
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
@ -335,12 +385,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [0, 0, 0],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 3
|
||||
self.block_length: 3,
|
||||
self.sloppy: sloppy
|
||||
})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testPartiallyEmptyOutputs(self):
|
||||
def testNonEmptyInputIntoEmptyOutputs(self):
|
||||
self._testNonEmptyInputIntoEmptyOutputs()
|
||||
|
||||
def testNonEmptyInputIntoEmptyOutputsSloppy(self):
|
||||
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
|
||||
|
||||
def _testPartiallyEmptyOutputs(self, sloppy=False):
|
||||
# Mixture of non-empty and empty interleaved datasets.
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
@ -350,7 +407,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 0, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1
|
||||
self.block_length: 1,
|
||||
self.sloppy: sloppy,
|
||||
})
|
||||
for i, expected_element in enumerate(
|
||||
self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
|
||||
@ -367,7 +425,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testDelayedOutput(self):
|
||||
def testPartiallyEmptyOutputs(self):
|
||||
self._testPartiallyEmptyOutputs()
|
||||
|
||||
def testPartiallyEmptyOutputsSloppy(self):
|
||||
self._testPartiallyEmptyOutputs(sloppy=True)
|
||||
|
||||
def testDelayedOutputSloppy(self):
|
||||
# Explicitly control the sequence of events to ensure we correctly avoid
|
||||
# head-of-line blocking.
|
||||
with self.test_session() as sess:
|
||||
@ -377,7 +441,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 1
|
||||
self.block_length: 1,
|
||||
self.sloppy: True,
|
||||
})
|
||||
|
||||
mis_ordering = [
|
||||
@ -391,7 +456,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testBlockLengthWithContention(self):
|
||||
def testBlockLengthWithContentionSloppy(self):
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
done_first_event = False
|
||||
@ -400,7 +465,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 2,
|
||||
self.block_length: 3
|
||||
self.block_length: 3,
|
||||
self.sloppy: True
|
||||
})
|
||||
# Test against a generating sequence that differs from the uncontended
|
||||
# case, in order to prove sloppy correctness.
|
||||
@ -422,7 +488,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.next_element)
|
||||
|
||||
def testEarlyExit(self):
|
||||
def _testEarlyExit(self, sloppy=False):
|
||||
# Exiting without consuming all input should not block
|
||||
with self.test_session() as sess:
|
||||
self._clear_coordination_events()
|
||||
@ -431,7 +497,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
feed_dict={
|
||||
self.input_values: [4, 5, 6],
|
||||
self.cycle_length: 3,
|
||||
self.block_length: 2
|
||||
self.block_length: 2,
|
||||
self.sloppy: sloppy
|
||||
})
|
||||
for i in range(4, 7):
|
||||
self.write_coordination_events[i].set()
|
||||
@ -445,7 +512,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
self.read_coordination_events[i].acquire()
|
||||
self.write_coordination_events[i].set()
|
||||
|
||||
def testTooManyReaders(self):
|
||||
def testEarlyExit(self):
|
||||
self._testEarlyExit()
|
||||
|
||||
def testEarlyExitSloppy(self):
|
||||
self._testEarlyExit(sloppy=True)
|
||||
|
||||
def _testTooManyReaders(self, sloppy=False):
|
||||
|
||||
def interleave_fn(x):
|
||||
dataset = dataset_ops.Dataset.from_tensors(x)
|
||||
@ -455,8 +528,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
|
||||
dataset = dataset.repeat(self.repeat_count)
|
||||
dataset = dataset.apply(
|
||||
sloppy_ops.sloppy_interleave(interleave_fn, cycle_length=16,
|
||||
block_length=2))
|
||||
interleave_ops.parallel_interleave(
|
||||
interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
|
||||
with self.test_session() as sess:
|
||||
@ -468,6 +541,11 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
||||
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
|
||||
self.assertItemsEqual(output_values, expected_values)
|
||||
|
||||
def testTooManyReaders(self):
|
||||
self._testTooManyReaders()
|
||||
|
||||
def testTooManyReadersSloppy(self):
|
||||
self._testTooManyReaders(sloppy=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -60,9 +60,9 @@ py_library(
|
||||
"enumerate_ops.py",
|
||||
"error_ops.py",
|
||||
"grouping.py",
|
||||
"interleave_ops.py",
|
||||
"resampling.py",
|
||||
"scan_ops.py",
|
||||
"sloppy_ops.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
|
@ -23,14 +23,16 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.util import deprecation
|
||||
|
||||
|
||||
class SloppyInterleaveDataset(dataset_ops.Dataset):
|
||||
class ParallelInterleaveDataset(dataset_ops.Dataset):
|
||||
"""A `Dataset` that maps a function over its input and flattens the result."""
|
||||
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length):
|
||||
"""See `tf.contrib.data.sloppy_interleave()` for details."""
|
||||
super(SloppyInterleaveDataset, self).__init__()
|
||||
def __init__(self, input_dataset, map_func, cycle_length, block_length,
|
||||
sloppy):
|
||||
"""See `tf.contrib.data.parallel_interleave()` for details."""
|
||||
super(ParallelInterleaveDataset, self).__init__()
|
||||
self._input_dataset = input_dataset
|
||||
|
||||
@function.Defun(*nest.flatten(input_dataset.output_types))
|
||||
@ -62,13 +64,16 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
|
||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||
self._block_length = ops.convert_to_tensor(
|
||||
block_length, dtype=dtypes.int64, name="block_length")
|
||||
self._sloppy = ops.convert_to_tensor(
|
||||
sloppy, dtype=dtypes.bool, name="sloppy")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.sloppy_interleave_dataset(
|
||||
return gen_dataset_ops.parallel_interleave_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
self._map_func.captured_inputs,
|
||||
self._cycle_length,
|
||||
self._block_length,
|
||||
self._sloppy,
|
||||
f=self._map_func,
|
||||
output_types=nest.flatten(self.output_types),
|
||||
output_shapes=nest.flatten(self.output_shapes))
|
||||
@ -82,6 +87,53 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
|
||||
return self._output_types
|
||||
|
||||
|
||||
def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
|
||||
"""A parallel version of the `Dataset.interleave()` transformation.
|
||||
|
||||
`parallel_interleave()` maps `map_func` across its input to produce nested
|
||||
datasets, and outputs their elements interleaved. Unlike
|
||||
@{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
|
||||
datasets in parallel, which increases the throughput, especially in the
|
||||
presence of stragglers. Furthermore, the `sloppy` argument can be used to
|
||||
improve performance, by relaxing the requirement that the outputs are produced
|
||||
in a deterministic order, and allowing the implementation to skip over nested
|
||||
datasets whose elements are not readily available when requested.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
# Preprocess 4 files concurrently.
|
||||
filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
|
||||
dataset = filenames.apply(
|
||||
tf.contrib.data.parallel_interleave(
|
||||
lambda filename: tf.data.TFRecordDataset(filename),
|
||||
cycle_length=4))
|
||||
```
|
||||
|
||||
WARNING: If `sloppy` is `True`, the order of produced elements is not
|
||||
deterministic.
|
||||
|
||||
Args:
|
||||
map_func: A function mapping a nested structure of tensors to a `Dataset`.
|
||||
cycle_length: The number of threads to interleave from in parallel.
|
||||
block_length: The number of consecutive elements to pull from a thread
|
||||
before advancing to the next thread.
|
||||
sloppy: If false, elements are produced in deterministic order. Otherwise,
|
||||
the implementation is allowed, for the sake of expediency, to produce
|
||||
elements in a non-deterministic order.
|
||||
|
||||
Returns:
|
||||
A `Dataset` transformation function, which can be passed to
|
||||
@{tf.data.Dataset.apply}.
|
||||
"""
|
||||
def _apply_fn(dataset):
|
||||
return ParallelInterleaveDataset(
|
||||
dataset, map_func, cycle_length, block_length, sloppy)
|
||||
return _apply_fn
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
|
||||
def sloppy_interleave(map_func, cycle_length, block_length=1):
|
||||
"""A non-deterministic version of the `Dataset.interleave()` transformation.
|
||||
|
||||
@ -132,6 +184,6 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
|
||||
@{tf.data.Dataset.apply}.
|
||||
"""
|
||||
def _apply_fn(dataset):
|
||||
return SloppyInterleaveDataset(
|
||||
dataset, map_func, cycle_length, block_length)
|
||||
return ParallelInterleaveDataset(
|
||||
dataset, map_func, cycle_length, block_length, sloppy=True)
|
||||
return _apply_fn
|
@ -5924,8 +5924,8 @@ tf_kernel_library(
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "sloppy_interleave_dataset_op",
|
||||
srcs = ["sloppy_interleave_dataset_op.cc"],
|
||||
name = "parallel_interleave_dataset_op",
|
||||
srcs = ["parallel_interleave_dataset_op.cc"],
|
||||
deps = [
|
||||
":captured_function",
|
||||
":dataset",
|
||||
@ -6162,6 +6162,7 @@ tf_kernel_library(
|
||||
":map_and_batch_dataset_op",
|
||||
":map_dataset_op",
|
||||
":padded_batch_dataset_op",
|
||||
":parallel_interleave_dataset_op",
|
||||
":parallel_map_dataset_op",
|
||||
":prefetch_dataset_op",
|
||||
":range_dataset_op",
|
||||
@ -6170,7 +6171,6 @@ tf_kernel_library(
|
||||
":scan_dataset_op",
|
||||
":shuffle_dataset_op",
|
||||
":skip_dataset_op",
|
||||
":sloppy_interleave_dataset_op",
|
||||
":sparse_tensor_slice_dataset_op",
|
||||
":sql_dataset_ops",
|
||||
":take_dataset_op",
|
||||
|
@ -336,7 +336,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
const std::unique_ptr<CapturedFunction> captured_func_;
|
||||
const Eigen::ThreadPoolDevice* device_; // not owned
|
||||
const Eigen::ThreadPoolDevice* device_; // not owned
|
||||
};
|
||||
|
||||
const int graph_def_version_;
|
||||
|
@ -17,12 +17,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/captured_function.h"
|
||||
#include "tensorflow/core/kernels/dataset_utils.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
|
||||
#include "tensorflow/core/kernels/captured_function.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
@ -30,9 +29,9 @@ namespace {
|
||||
// See documentation in ../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
|
||||
class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx)
|
||||
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx),
|
||||
graph_def_version_(ctx->graph_def_version()) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
|
||||
@ -62,13 +61,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
OP_REQUIRES(ctx, block_length > 0,
|
||||
errors::InvalidArgument("`block_length` must be > 0"));
|
||||
|
||||
bool sloppy;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
|
||||
|
||||
std::unique_ptr<CapturedFunction> captured_func;
|
||||
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
|
||||
std::move(other_arguments),
|
||||
&captured_func));
|
||||
|
||||
*output = new Dataset(input, std::move(captured_func), cycle_length,
|
||||
block_length, output_types_, output_shapes_);
|
||||
block_length, sloppy, output_types_, output_shapes_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -76,12 +78,13 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
Dataset(const DatasetBase* input,
|
||||
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
|
||||
int64 block_length, const DataTypeVector& output_types,
|
||||
int64 block_length, bool sloppy, const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: input_(input),
|
||||
captured_func_(std::move(captured_func)),
|
||||
cycle_length_(cycle_length),
|
||||
block_length_(block_length),
|
||||
sloppy_(sloppy),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes) {
|
||||
input_->Ref();
|
||||
@ -91,8 +94,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIterator(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(
|
||||
new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")}));
|
||||
return std::unique_ptr<IteratorBase>(new Iterator(
|
||||
{this, strings::StrCat(prefix, "::ParallelInterleave")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
@ -103,7 +106,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
string DebugString() override {
|
||||
return "SloppyInterleaveDatasetOp::Dataset";
|
||||
return "ParallelInterleaveDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
private:
|
||||
@ -131,16 +134,24 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
|
||||
// Search for available items, blocking if necessary.
|
||||
const int64 num_workers = worker_threads_.size();
|
||||
if (num_workers == 0) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
while (!cancelled_) {
|
||||
for (size_t i = 0; i < dataset()->cycle_length_; ++i) {
|
||||
size_t index = (next_index_ + i) % dataset()->cycle_length_;
|
||||
// Wait for an item to become available, blocking if necessary. If we
|
||||
// are allowed to be sloppy, we can skip over input datasets that do
|
||||
// not have an item readily available.
|
||||
const int64 n = dataset()->sloppy_ ? num_workers : 1LL;
|
||||
for (int64 i = 0; i < n; ++i) {
|
||||
int64 index = (next_index_ + i) % num_workers;
|
||||
if (output_elements_[index].is_produced) {
|
||||
next_index_ = index;
|
||||
if (i == 0) {
|
||||
block_count_++;
|
||||
if (block_count_ == dataset()->block_length_) {
|
||||
next_index_ = (index + 1) % dataset()->cycle_length_;
|
||||
next_index_ = (index + 1) % num_workers;
|
||||
block_count_ = 0;
|
||||
}
|
||||
} else {
|
||||
@ -150,7 +161,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (output_elements_[index].end_of_sequence) {
|
||||
output_elements_[index].is_produced = false;
|
||||
output_elements_[index].cond_var.notify_one();
|
||||
next_index_ = (index + 1) % dataset()->cycle_length_;
|
||||
next_index_ = (index + 1) % num_workers;
|
||||
block_count_ = 0;
|
||||
i = -1; // Restart the inner loop
|
||||
continue;
|
||||
@ -174,11 +185,21 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If we are not allowed to be sloppy and
|
||||
// `worker_threads_[next_index_]` has finished, advance `next_index_`.
|
||||
if (!dataset()->sloppy_ && worker_threads_[next_index_].finished) {
|
||||
next_index_ = (next_index_ + 1) % num_workers;
|
||||
continue;
|
||||
}
|
||||
|
||||
// No values available; wait until woken up.
|
||||
// TODO(jsimsa): Use slot-specific condition variable for
|
||||
// coordination of elements consumption.
|
||||
cond_var_.wait(l);
|
||||
}
|
||||
return errors::Cancelled(
|
||||
"SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext");
|
||||
"ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
|
||||
}
|
||||
|
||||
private:
|
||||
@ -201,6 +222,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
condition_variable cond_var;
|
||||
};
|
||||
|
||||
struct ThreadStatus {
|
||||
// The underlying thread uses `finished` to communicate to the producer
|
||||
// that it has finished.
|
||||
bool finished = false;
|
||||
// The underlying thread object.
|
||||
std::unique_ptr<Thread> thread;
|
||||
|
||||
explicit ThreadStatus(Thread* thread) : thread(thread) {}
|
||||
};
|
||||
|
||||
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (worker_threads_.empty()) {
|
||||
@ -220,11 +251,10 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::unique_ptr<IteratorBase> itr;
|
||||
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
|
||||
ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
|
||||
worker_threads_.emplace_back(
|
||||
std::unique_ptr<Thread>(ctx->env()->StartThread(
|
||||
{}, "worker_thread",
|
||||
std::bind(&Iterator::WorkerThread, this,
|
||||
new IteratorContext(*ctx), i, itr.release()))));
|
||||
worker_threads_.emplace_back(ctx->env()->StartThread(
|
||||
{}, "worker_thread",
|
||||
std::bind(&Iterator::WorkerThread, this,
|
||||
new IteratorContext(*ctx), i, itr.release())));
|
||||
num_active_threads_ = i + 1;
|
||||
}
|
||||
}
|
||||
@ -264,6 +294,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
|
||||
auto cleanup = gtl::MakeCleanup([this, thread_index] {
|
||||
mutex_lock l(mu_);
|
||||
worker_threads_[thread_index].finished = true;
|
||||
num_active_threads_--;
|
||||
cond_var_.notify_all();
|
||||
});
|
||||
@ -345,13 +376,14 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Pointers to the worker threads. This must be last to ensure the
|
||||
// threads have exited before any other members are deallocated.
|
||||
// TODO(b/65178177): Avoid allocating additional threads.
|
||||
std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
|
||||
std::vector<ThreadStatus> worker_threads_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
const std::unique_ptr<CapturedFunction> captured_func_;
|
||||
const int64 cycle_length_;
|
||||
const int64 block_length_;
|
||||
const bool sloppy_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
@ -362,8 +394,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
NameAttrList func_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU),
|
||||
SloppyInterleaveDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
|
||||
ParallelInterleaveDatasetOp);
|
||||
|
||||
} // namespace
|
||||
|
@ -59,7 +59,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
|
||||
Dataset(const DatasetBase* input, int64 buffer_size,
|
||||
IteratorContext::Params ctx_params)
|
||||
: input_(input),
|
||||
|
||||
buffer_size_(buffer_size),
|
||||
ctx_params_(std::move(ctx_params)) {
|
||||
input_->Ref();
|
||||
|
@ -32629,95 +32629,6 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "SloppyInterleaveDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "other_arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "cycle_length"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "block_length"
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "SloppyInterleaveDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "other_arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "cycle_length"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "block_length"
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Softmax"
|
||||
input_arg {
|
||||
|
@ -285,11 +285,12 @@ f: A function mapping elements of `input_dataset`, concatenated with
|
||||
`output_types` and `output_shapes`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SloppyInterleaveDataset")
|
||||
REGISTER_OP("ParallelInterleaveDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Input("cycle_length: int64")
|
||||
.Input("block_length: int64")
|
||||
.Input("sloppy: bool")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
|
Loading…
Reference in New Issue
Block a user