Generalizing sloppy_interleave, making sloppiness an option.

PiperOrigin-RevId: 173687797
This commit is contained in:
Jiri Simsa 2017-10-27 10:29:36 -07:00 committed by TensorFlower Gardener
parent 7775a66043
commit 6b05b36cd2
11 changed files with 283 additions and 210 deletions

View File

@ -50,6 +50,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element
from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.grouping import group_by_window
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import read_batch_features
@ -57,7 +58,6 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset
from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TextLineDataset
from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset
from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave
from tensorflow.python.data.ops.iterator_ops import Iterator from tensorflow.python.data.ops.iterator_ops import Iterator
# pylint: enable=unused-import # pylint: enable=unused-import

View File

@ -143,6 +143,29 @@ py_test(
], ],
) )
py_test(
name = "interleave_dataset_op_test",
size = "small",
srcs = ["interleave_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual", # b/67958761
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],
)
py_test( py_test(
name = "iterator_ops_cluster_test", name = "iterator_ops_cluster_test",
size = "small", size = "small",
@ -352,29 +375,6 @@ py_test(
], ],
) )
py_test(
name = "sloppy_transformation_dataset_op_test",
size = "small",
srcs = ["sloppy_transformation_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual", # b/67958761
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],
)
py_test( py_test(
name = "sql_dataset_op_test", name = "sql_dataset_op_test",
size = "small", size = "small",

View File

@ -25,7 +25,7 @@ import time
from six.moves import zip_longest from six.moves import zip_longest
from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import sloppy_ops from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -34,12 +34,13 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
class SloppyInterleaveDatasetTest(test.TestCase): class ParallelInterleaveDatasetTest(test.TestCase):
def setUp(self): def setUp(self):
self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
self.repeat_count = 2 self.repeat_count = 2
@ -69,9 +70,9 @@ class SloppyInterleaveDatasetTest(test.TestCase):
self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
.repeat(self.repeat_count).apply( .repeat(self.repeat_count).apply(
sloppy_ops.sloppy_interleave( interleave_ops.parallel_interleave(
interleave_fn, self.cycle_length, interleave_fn, self.cycle_length,
self.block_length))) self.block_length, self.sloppy)))
self.iterator = self.dataset.make_initializable_iterator() self.iterator = self.dataset.make_initializable_iterator()
self.init_op = self.iterator.initializer self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next() self.next_element = self.iterator.get_next()
@ -161,7 +162,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
for i in range(4, 7): for i in range(4, 7):
self.write_coordination_events[i].set() self.write_coordination_events[i].set()
def testSingleThreaded(self): def _testSingleThreaded(self, sloppy=False):
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required. # `Dataset.flat_map()` and is single-threaded. No synchronization required.
with self.test_session() as sess: with self.test_session() as sess:
@ -171,7 +172,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 5, 6], self.input_values: [4, 5, 6],
self.cycle_length: 1, self.cycle_length: 1,
self.block_length: 1 self.block_length: 1,
self.sloppy: sloppy
}) })
for expected_element in self._interleave( for expected_element in self._interleave(
@ -182,7 +184,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testTwoThreadsNoContention(self): def testSingleThreaded(self):
self._testSingleThreaded()
def testSingleThreadedSloppy(self):
self._testSingleThreaded(sloppy=True)
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1. # num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior # Explicit coordination should result in `Dataset.interleave()` behavior
with self.test_session() as sess: with self.test_session() as sess:
@ -193,7 +201,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 5, 6], self.input_values: [4, 5, 6],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 1 self.block_length: 1,
self.sloppy: sloppy
}) })
for i, expected_element in enumerate( for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -211,43 +220,59 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testTwoThreadsNoContention(self):
self._testTwoThreadsNoContention()
def testTwoThreadsNoContentionSloppy(self):
self._testTwoThreadsNoContention(sloppy=True)
def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
"""Tests where all the workers race in producing elements.
Note: this is in contrast with the prevous test which carefully sequences
the execution of the map functions.
Args:
sloppy: Whether to be sloppy or not.
"""
with self.test_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
self.init_op,
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: sloppy,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
1)):
if done_first_event: # First event starts the worker threads.
self._allow_all_map_threads()
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
self.read_coordination_events[expected_element].acquire(False))
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testTwoThreadsNoContentionWithRaces(self): def testTwoThreadsNoContentionWithRaces(self):
"""Tests where all the workers race in producing elements. self._testTwoThreadsNoContentionWithRaces()
Note: this is in contrast with the prevous test which carefully sequences def testTwoThreadsNoContentionWithRacesSloppy(self):
the execution of the map functions. self._testTwoThreadsNoContentionWithRaces(sloppy=True)
"""
with self.test_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
self.init_op,
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
1)):
if done_first_event: # First event starts the worker threads.
self._allow_all_map_threads()
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
self.read_coordination_events[expected_element].acquire(False))
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testTwoThreadsNoContentionBlockLength(self): def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
# num_threads > 1. # num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior # Explicit coordination should result in `Dataset.interleave()` behavior
with self.test_session() as sess: with self.test_session() as sess:
@ -258,7 +283,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 5, 6], self.input_values: [4, 5, 6],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 2 self.block_length: 2,
self.sloppy: sloppy
}) })
for i, expected_element in enumerate( for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -276,11 +302,21 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testTwoThreadsNoContentionWithRacesAndBlocking(self): def testTwoThreadsNoContentionBlockLength(self):
self._testTwoThreadsNoContentionBlockLength()
def testTwoThreadsNoContentionBlockLengthSloppy(self):
self._testTwoThreadsNoContentionBlockLength(sloppy=True)
def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
"""Tests where all the workers race in producing elements. """Tests where all the workers race in producing elements.
Note: this is in contrast with the prevous test which carefully sequences Note: this is in contrast with the prevous test which carefully sequences
the execution of the map functions. the execution of the map functions.
Args:
sloppy: Whether to be sloppy or not.
""" """
with self.test_session() as sess: with self.test_session() as sess:
self._clear_coordination_events() self._clear_coordination_events()
@ -290,7 +326,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 5, 6], self.input_values: [4, 5, 6],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 2 self.block_length: 2,
self.sloppy: sloppy
}) })
for i, expected_element in enumerate( for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -312,7 +349,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testEmptyInput(self): def testTwoThreadsNoContentionWithRacesAndBlocking(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking()
def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
def _testEmptyInput(self, sloppy=False):
with self.test_session() as sess: with self.test_session() as sess:
# Empty input. # Empty input.
self._clear_coordination_events() self._clear_coordination_events()
@ -321,12 +364,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [], self.input_values: [],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 3 self.block_length: 3,
self.sloppy: sloppy
}) })
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testNonEmptyInputIntoEmptyOutputs(self): def testEmptyInput(self):
self._testEmptyInput()
def testEmptyInputSloppy(self):
self._testEmptyInput(sloppy=True)
def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
# Non-empty input leading to empty output. # Non-empty input leading to empty output.
with self.test_session() as sess: with self.test_session() as sess:
self._clear_coordination_events() self._clear_coordination_events()
@ -335,12 +385,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [0, 0, 0], self.input_values: [0, 0, 0],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 3 self.block_length: 3,
self.sloppy: sloppy
}) })
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testPartiallyEmptyOutputs(self): def testNonEmptyInputIntoEmptyOutputs(self):
self._testNonEmptyInputIntoEmptyOutputs()
def testNonEmptyInputIntoEmptyOutputsSloppy(self):
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
def _testPartiallyEmptyOutputs(self, sloppy=False):
# Mixture of non-empty and empty interleaved datasets. # Mixture of non-empty and empty interleaved datasets.
with self.test_session() as sess: with self.test_session() as sess:
self._clear_coordination_events() self._clear_coordination_events()
@ -350,7 +407,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 0, 6], self.input_values: [4, 0, 6],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 1 self.block_length: 1,
self.sloppy: sloppy,
}) })
for i, expected_element in enumerate( for i, expected_element in enumerate(
self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
@ -367,7 +425,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testDelayedOutput(self): def testPartiallyEmptyOutputs(self):
self._testPartiallyEmptyOutputs()
def testPartiallyEmptyOutputsSloppy(self):
self._testPartiallyEmptyOutputs(sloppy=True)
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid # Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking. # head-of-line blocking.
with self.test_session() as sess: with self.test_session() as sess:
@ -377,7 +441,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 5, 6], self.input_values: [4, 5, 6],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 1 self.block_length: 1,
self.sloppy: True,
}) })
mis_ordering = [ mis_ordering = [
@ -391,7 +456,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testBlockLengthWithContention(self): def testBlockLengthWithContentionSloppy(self):
with self.test_session() as sess: with self.test_session() as sess:
self._clear_coordination_events() self._clear_coordination_events()
done_first_event = False done_first_event = False
@ -400,7 +465,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 5, 6], self.input_values: [4, 5, 6],
self.cycle_length: 2, self.cycle_length: 2,
self.block_length: 3 self.block_length: 3,
self.sloppy: True
}) })
# Test against a generating sequence that differs from the uncontended # Test against a generating sequence that differs from the uncontended
# case, in order to prove sloppy correctness. # case, in order to prove sloppy correctness.
@ -422,7 +488,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element) sess.run(self.next_element)
def testEarlyExit(self): def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block # Exiting without consuming all input should not block
with self.test_session() as sess: with self.test_session() as sess:
self._clear_coordination_events() self._clear_coordination_events()
@ -431,7 +497,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={ feed_dict={
self.input_values: [4, 5, 6], self.input_values: [4, 5, 6],
self.cycle_length: 3, self.cycle_length: 3,
self.block_length: 2 self.block_length: 2,
self.sloppy: sloppy
}) })
for i in range(4, 7): for i in range(4, 7):
self.write_coordination_events[i].set() self.write_coordination_events[i].set()
@ -445,7 +512,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
self.read_coordination_events[i].acquire() self.read_coordination_events[i].acquire()
self.write_coordination_events[i].set() self.write_coordination_events[i].set()
def testTooManyReaders(self): def testEarlyExit(self):
self._testEarlyExit()
def testEarlyExitSloppy(self):
self._testEarlyExit(sloppy=True)
def _testTooManyReaders(self, sloppy=False):
def interleave_fn(x): def interleave_fn(x):
dataset = dataset_ops.Dataset.from_tensors(x) dataset = dataset_ops.Dataset.from_tensors(x)
@ -455,8 +528,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
dataset = dataset.repeat(self.repeat_count) dataset = dataset.repeat(self.repeat_count)
dataset = dataset.apply( dataset = dataset.apply(
sloppy_ops.sloppy_interleave(interleave_fn, cycle_length=16, interleave_ops.parallel_interleave(
block_length=2)) interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
iterator = dataset.make_one_shot_iterator() iterator = dataset.make_one_shot_iterator()
with self.test_session() as sess: with self.test_session() as sess:
@ -468,6 +541,11 @@ class SloppyInterleaveDatasetTest(test.TestCase):
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
self.assertItemsEqual(output_values, expected_values) self.assertItemsEqual(output_values, expected_values)
def testTooManyReaders(self):
self._testTooManyReaders()
def testTooManyReadersSloppy(self):
self._testTooManyReaders(sloppy=True)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -60,9 +60,9 @@ py_library(
"enumerate_ops.py", "enumerate_ops.py",
"error_ops.py", "error_ops.py",
"grouping.py", "grouping.py",
"interleave_ops.py",
"resampling.py", "resampling.py",
"scan_ops.py", "scan_ops.py",
"sloppy_ops.py",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [

View File

@ -23,14 +23,16 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
class SloppyInterleaveDataset(dataset_ops.Dataset): class ParallelInterleaveDataset(dataset_ops.Dataset):
"""A `Dataset` that maps a function over its input and flattens the result.""" """A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func, cycle_length, block_length): def __init__(self, input_dataset, map_func, cycle_length, block_length,
"""See `tf.contrib.data.sloppy_interleave()` for details.""" sloppy):
super(SloppyInterleaveDataset, self).__init__() """See `tf.contrib.data.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__()
self._input_dataset = input_dataset self._input_dataset = input_dataset
@function.Defun(*nest.flatten(input_dataset.output_types)) @function.Defun(*nest.flatten(input_dataset.output_types))
@ -62,13 +64,16 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
cycle_length, dtype=dtypes.int64, name="cycle_length") cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor( self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length") block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
def _as_variant_tensor(self): def _as_variant_tensor(self):
return gen_dataset_ops.sloppy_interleave_dataset( return gen_dataset_ops.parallel_interleave_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._map_func.captured_inputs, self._map_func.captured_inputs,
self._cycle_length, self._cycle_length,
self._block_length, self._block_length,
self._sloppy,
f=self._map_func, f=self._map_func,
output_types=nest.flatten(self.output_types), output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes)) output_shapes=nest.flatten(self.output_shapes))
@ -82,6 +87,53 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
return self._output_types return self._output_types
def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
"""A parallel version of the `Dataset.interleave()` transformation.
`parallel_interleave()` maps `map_func` across its input to produce nested
datasets, and outputs their elements interleaved. Unlike
@{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
datasets in parallel, which increases the throughput, especially in the
presence of stragglers. Furthermore, the `sloppy` argument can be used to
improve performance, by relaxing the requirement that the outputs are produced
in a deterministic order, and allowing the implementation to skip over nested
datasets whose elements are not readily available when requested.
Example usage:
```python
# Preprocess 4 files concurrently.
filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
dataset = filenames.apply(
tf.contrib.data.parallel_interleave(
lambda filename: tf.data.TFRecordDataset(filename),
cycle_length=4))
```
WARNING: If `sloppy` is `True`, the order of produced elements is not
deterministic.
Args:
map_func: A function mapping a nested structure of tensors to a `Dataset`.
cycle_length: The number of threads to interleave from in parallel.
block_length: The number of consecutive elements to pull from a thread
before advancing to the next thread.
sloppy: If false, elements are produced in deterministic order. Otherwise,
the implementation is allowed, for the sake of expediency, to produce
elements in a non-deterministic order.
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return ParallelInterleaveDataset(
dataset, map_func, cycle_length, block_length, sloppy)
return _apply_fn
@deprecation.deprecated(
None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
def sloppy_interleave(map_func, cycle_length, block_length=1): def sloppy_interleave(map_func, cycle_length, block_length=1):
"""A non-deterministic version of the `Dataset.interleave()` transformation. """A non-deterministic version of the `Dataset.interleave()` transformation.
@ -132,6 +184,6 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
@{tf.data.Dataset.apply}. @{tf.data.Dataset.apply}.
""" """
def _apply_fn(dataset): def _apply_fn(dataset):
return SloppyInterleaveDataset( return ParallelInterleaveDataset(
dataset, map_func, cycle_length, block_length) dataset, map_func, cycle_length, block_length, sloppy=True)
return _apply_fn return _apply_fn

View File

@ -5924,8 +5924,8 @@ tf_kernel_library(
) )
tf_kernel_library( tf_kernel_library(
name = "sloppy_interleave_dataset_op", name = "parallel_interleave_dataset_op",
srcs = ["sloppy_interleave_dataset_op.cc"], srcs = ["parallel_interleave_dataset_op.cc"],
deps = [ deps = [
":captured_function", ":captured_function",
":dataset", ":dataset",
@ -6162,6 +6162,7 @@ tf_kernel_library(
":map_and_batch_dataset_op", ":map_and_batch_dataset_op",
":map_dataset_op", ":map_dataset_op",
":padded_batch_dataset_op", ":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op", ":parallel_map_dataset_op",
":prefetch_dataset_op", ":prefetch_dataset_op",
":range_dataset_op", ":range_dataset_op",
@ -6170,7 +6171,6 @@ tf_kernel_library(
":scan_dataset_op", ":scan_dataset_op",
":shuffle_dataset_op", ":shuffle_dataset_op",
":skip_dataset_op", ":skip_dataset_op",
":sloppy_interleave_dataset_op",
":sparse_tensor_slice_dataset_op", ":sparse_tensor_slice_dataset_op",
":sql_dataset_ops", ":sql_dataset_ops",
":take_dataset_op", ":take_dataset_op",

View File

@ -336,7 +336,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector output_types_; const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_; const std::vector<PartialTensorShape> output_shapes_;
const std::unique_ptr<CapturedFunction> captured_func_; const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned const Eigen::ThreadPoolDevice* device_; // not owned
}; };
const int graph_def_version_; const int graph_def_version_;

View File

@ -17,12 +17,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/captured_function.h"
#include "tensorflow/core/kernels/dataset_utils.h" #include "tensorflow/core/kernels/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/kernels/captured_function.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -30,9 +29,9 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level // See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op. // description of the following op.
class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel { class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public: public:
explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx) explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx), : UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) { graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
@ -62,13 +61,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES(ctx, block_length > 0, OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0")); errors::InvalidArgument("`block_length` must be > 0"));
bool sloppy;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
std::unique_ptr<CapturedFunction> captured_func; std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_, OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
std::move(other_arguments), std::move(other_arguments),
&captured_func)); &captured_func));
*output = new Dataset(input, std::move(captured_func), cycle_length, *output = new Dataset(input, std::move(captured_func), cycle_length,
block_length, output_types_, output_shapes_); block_length, sloppy, output_types_, output_shapes_);
} }
private: private:
@ -76,12 +78,13 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
public: public:
Dataset(const DatasetBase* input, Dataset(const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, const DataTypeVector& output_types, int64 block_length, bool sloppy, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes) const std::vector<PartialTensorShape>& output_shapes)
: input_(input), : input_(input),
captured_func_(std::move(captured_func)), captured_func_(std::move(captured_func)),
cycle_length_(cycle_length), cycle_length_(cycle_length),
block_length_(block_length), block_length_(block_length),
sloppy_(sloppy),
output_types_(output_types), output_types_(output_types),
output_shapes_(output_shapes) { output_shapes_(output_shapes) {
input_->Ref(); input_->Ref();
@ -91,8 +94,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIterator( std::unique_ptr<IteratorBase> MakeIterator(
const string& prefix) const override { const string& prefix) const override {
return std::unique_ptr<IteratorBase>( return std::unique_ptr<IteratorBase>(new Iterator(
new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")})); {this, strings::StrCat(prefix, "::ParallelInterleave")}));
} }
const DataTypeVector& output_dtypes() const override { const DataTypeVector& output_dtypes() const override {
@ -103,7 +106,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
} }
string DebugString() override { string DebugString() override {
return "SloppyInterleaveDatasetOp::Dataset"; return "ParallelInterleaveDatasetOp::Dataset";
} }
private: private:
@ -131,16 +134,24 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override { bool* end_of_sequence) override {
mutex_lock l(mu_); mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
// Search for available items, blocking if necessary. const int64 num_workers = worker_threads_.size();
if (num_workers == 0) {
*end_of_sequence = true;
return Status::OK();
}
while (!cancelled_) { while (!cancelled_) {
for (size_t i = 0; i < dataset()->cycle_length_; ++i) { // Wait for an item to become available, blocking if necessary. If we
size_t index = (next_index_ + i) % dataset()->cycle_length_; // are allowed to be sloppy, we can skip over input datasets that do
// not have an item readily available.
const int64 n = dataset()->sloppy_ ? num_workers : 1LL;
for (int64 i = 0; i < n; ++i) {
int64 index = (next_index_ + i) % num_workers;
if (output_elements_[index].is_produced) { if (output_elements_[index].is_produced) {
next_index_ = index; next_index_ = index;
if (i == 0) { if (i == 0) {
block_count_++; block_count_++;
if (block_count_ == dataset()->block_length_) { if (block_count_ == dataset()->block_length_) {
next_index_ = (index + 1) % dataset()->cycle_length_; next_index_ = (index + 1) % num_workers;
block_count_ = 0; block_count_ = 0;
} }
} else { } else {
@ -150,7 +161,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (output_elements_[index].end_of_sequence) { if (output_elements_[index].end_of_sequence) {
output_elements_[index].is_produced = false; output_elements_[index].is_produced = false;
output_elements_[index].cond_var.notify_one(); output_elements_[index].cond_var.notify_one();
next_index_ = (index + 1) % dataset()->cycle_length_; next_index_ = (index + 1) % num_workers;
block_count_ = 0; block_count_ = 0;
i = -1; // Restart the inner loop i = -1; // Restart the inner loop
continue; continue;
@ -174,11 +185,21 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true; *end_of_sequence = true;
return Status::OK(); return Status::OK();
} }
// If we are not allowed to be sloppy and
// `worker_threads_[next_index]` has finished, advance `next_index`.
if (!dataset()->sloppy_ && worker_threads_[next_index_].finished) {
next_index_ = (next_index_ + 1) % num_workers;
continue;
}
// No values available; wait until woken up. // No values available; wait until woken up.
// TODO(jsimsa): Use slot-specific condition variable for
// coordination of elements consumption.
cond_var_.wait(l); cond_var_.wait(l);
} }
return errors::Cancelled( return errors::Cancelled(
"SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext"); "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
} }
private: private:
@ -201,6 +222,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
condition_variable cond_var; condition_variable cond_var;
}; };
struct ThreadStatus {
// The underlying thread uses `finished` to communicate to the producer
// that it has finished.
bool finished = false;
// The underlying thread object.
std::unique_ptr<Thread> thread;
explicit ThreadStatus(Thread* thread) : thread(thread) {}
};
Status EnsureWorkerThreadsStarted(IteratorContext* ctx) Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) { EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (worker_threads_.empty()) { if (worker_threads_.empty()) {
@ -220,11 +251,10 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> itr; std::unique_ptr<IteratorBase> itr;
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr)); ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
worker_threads_.emplace_back( worker_threads_.emplace_back(ctx->env()->StartThread(
std::unique_ptr<Thread>(ctx->env()->StartThread( {}, "worker_thread",
{}, "worker_thread", std::bind(&Iterator::WorkerThread, this,
std::bind(&Iterator::WorkerThread, this, new IteratorContext(*ctx), i, itr.release())));
new IteratorContext(*ctx), i, itr.release()))));
num_active_threads_ = i + 1; num_active_threads_ = i + 1;
} }
} }
@ -264,6 +294,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr); std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
auto cleanup = gtl::MakeCleanup([this, thread_index] { auto cleanup = gtl::MakeCleanup([this, thread_index] {
mutex_lock l(mu_); mutex_lock l(mu_);
worker_threads_[thread_index].finished = true;
num_active_threads_--; num_active_threads_--;
cond_var_.notify_all(); cond_var_.notify_all();
}); });
@ -345,13 +376,14 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Pointers to the worker threads. This must be last to ensure the // Pointers to the worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated. // threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads. // TODO(b/65178177): Avoid allocating additional threads.
std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_); std::vector<ThreadStatus> worker_threads_ GUARDED_BY(mu_);
}; };
const DatasetBase* const input_; const DatasetBase* const input_;
const std::unique_ptr<CapturedFunction> captured_func_; const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_; const int64 cycle_length_;
const int64 block_length_; const int64 block_length_;
const bool sloppy_;
const DataTypeVector output_types_; const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_; const std::vector<PartialTensorShape> output_shapes_;
}; };
@ -362,8 +394,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
NameAttrList func_; NameAttrList func_;
}; };
REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
SloppyInterleaveDatasetOp); ParallelInterleaveDatasetOp);
} // namespace } // namespace

View File

@ -59,7 +59,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
Dataset(const DatasetBase* input, int64 buffer_size, Dataset(const DatasetBase* input, int64 buffer_size,
IteratorContext::Params ctx_params) IteratorContext::Params ctx_params)
: input_(input), : input_(input),
buffer_size_(buffer_size), buffer_size_(buffer_size),
ctx_params_(std::move(ctx_params)) { ctx_params_(std::move(ctx_params)) {
input_->Ref(); input_->Ref();

View File

@ -32629,95 +32629,6 @@ op {
} }
} }
} }
op {
name: "SloppyInterleaveDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "cycle_length"
type: DT_INT64
}
input_arg {
name: "block_length"
type: DT_INT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "SloppyInterleaveDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "cycle_length"
type: DT_INT64
}
input_arg {
name: "block_length"
type: DT_INT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
}
op { op {
name: "Softmax" name: "Softmax"
input_arg { input_arg {

View File

@ -285,11 +285,12 @@ f: A function mapping elements of `input_dataset`, concatenated with
`output_types` and `output_shapes`. `output_types` and `output_shapes`.
)doc"); )doc");
REGISTER_OP("SloppyInterleaveDataset") REGISTER_OP("ParallelInterleaveDataset")
.Input("input_dataset: variant") .Input("input_dataset: variant")
.Input("other_arguments: Targuments") .Input("other_arguments: Targuments")
.Input("cycle_length: int64") .Input("cycle_length: int64")
.Input("block_length: int64") .Input("block_length: int64")
.Input("sloppy: bool")
.Output("handle: variant") .Output("handle: variant")
.Attr("f: func") .Attr("f: func")
.Attr("Targuments: list(type) >= 0") .Attr("Targuments: list(type) >= 0")