Generalizing sloppy_interleave, making sloppiness an option.
PiperOrigin-RevId: 173687797
parent 7775a66043
commit 6b05b36cd2
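In effect, sloppiness stops being a separate transformation and becomes a boolean option: `parallel_interleave(..., sloppy=False)` is deterministic by default, and `sloppy_interleave` survives only as a deprecated alias for the sloppy mode. A minimal before/after sketch of the call sites; it assumes `parallel_interleave` is reachable under `tf.contrib.data`, as the new docstring's example anticipates:

```python
import tensorflow as tf  # a tf.contrib-era (1.x) build containing this change

filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")

# Before: non-determinism was baked into the transformation's name.
dataset = filenames.apply(
    tf.contrib.data.sloppy_interleave(
        lambda filename: tf.data.TFRecordDataset(filename), cycle_length=4))

# After: one transformation, deterministic unless you opt in to sloppiness.
dataset = filenames.apply(
    tf.contrib.data.parallel_interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=4, sloppy=True))
```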
@@ -50,6 +50,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element
 from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
 from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
 from tensorflow.contrib.data.python.ops.grouping import group_by_window
+from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
 from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
 from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset
 from tensorflow.contrib.data.python.ops.readers import read_batch_features
@@ -57,7 +58,6 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset
 from tensorflow.contrib.data.python.ops.readers import TextLineDataset
 from tensorflow.contrib.data.python.ops.readers import TFRecordDataset
 from tensorflow.contrib.data.python.ops.resampling import rejection_resample
-from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave
 from tensorflow.python.data.ops.iterator_ops import Iterator
 # pylint: enable=unused-import
 
@@ -143,6 +143,29 @@ py_test(
     ],
 )
 
+py_test(
+    name = "interleave_dataset_op_test",
+    size = "small",
+    srcs = ["interleave_dataset_op_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "manual",  # b/67958761
+    ],
+    deps = [
+        "//tensorflow/contrib/data/python/ops:dataset_ops",
+        "//tensorflow/contrib/data/python/ops:transformation_ops",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:script_ops",
+        "//tensorflow/python:training",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_test(
     name = "iterator_ops_cluster_test",
     size = "small",
@@ -352,29 +375,6 @@ py_test(
     ],
 )
 
-py_test(
-    name = "sloppy_transformation_dataset_op_test",
-    size = "small",
-    srcs = ["sloppy_transformation_dataset_op_test.py"],
-    srcs_version = "PY2AND3",
-    tags = [
-        "manual",  # b/67958761
-    ],
-    deps = [
-        "//tensorflow/contrib/data/python/ops:dataset_ops",
-        "//tensorflow/contrib/data/python/ops:transformation_ops",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:script_ops",
-        "//tensorflow/python:training",
-        "//third_party/py/numpy",
-    ],
-)
-
 py_test(
     name = "sql_dataset_op_test",
     size = "small",
@@ -25,7 +25,7 @@ import time
 from six.moves import zip_longest
 
 from tensorflow.contrib.data.python.ops import dataset_ops
-from tensorflow.contrib.data.python.ops import sloppy_ops
+from tensorflow.contrib.data.python.ops import interleave_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.ops import array_ops
@@ -34,12 +34,13 @@ from tensorflow.python.ops import script_ops
 from tensorflow.python.platform import test
 
 
-class SloppyInterleaveDatasetTest(test.TestCase):
+class ParallelInterleaveDatasetTest(test.TestCase):
 
   def setUp(self):
     self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
     self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
     self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
+    self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
 
     self.repeat_count = 2
 
@@ -69,9 +70,9 @@ class SloppyInterleaveDatasetTest(test.TestCase):
 
     self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
                     .repeat(self.repeat_count).apply(
-                        sloppy_ops.sloppy_interleave(
+                        interleave_ops.parallel_interleave(
                             interleave_fn, self.cycle_length,
-                            self.block_length)))
+                            self.block_length, self.sloppy)))
     self.iterator = self.dataset.make_initializable_iterator()
     self.init_op = self.iterator.initializer
     self.next_element = self.iterator.get_next()
@@ -161,7 +162,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
     for i in range(4, 7):
       self.write_coordination_events[i].set()
 
-  def testSingleThreaded(self):
+  def _testSingleThreaded(self, sloppy=False):
     # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
     # `Dataset.flat_map()` and is single-threaded. No synchronization required.
     with self.test_session() as sess:
@@ -171,7 +172,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
           feed_dict={
              self.input_values: [4, 5, 6],
              self.cycle_length: 1,
-              self.block_length: 1
+              self.block_length: 1,
+              self.sloppy: sloppy
          })
 
      for expected_element in self._interleave(
@@ -182,7 +184,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testTwoThreadsNoContention(self):
+  def testSingleThreaded(self):
+    self._testSingleThreaded()
+
+  def testSingleThreadedSloppy(self):
+    self._testSingleThreaded(sloppy=True)
+
+  def _testTwoThreadsNoContention(self, sloppy=False):
    # num_threads > 1.
    # Explicit coordination should result in `Dataset.interleave()` behavior
    with self.test_session() as sess:
@@ -193,7 +201,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [4, 5, 6],
              self.cycle_length: 2,
-              self.block_length: 1
+              self.block_length: 1,
+              self.sloppy: sloppy
          })
      for i, expected_element in enumerate(
          self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@@ -211,43 +220,59 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
+  def testTwoThreadsNoContention(self):
+    self._testTwoThreadsNoContention()
+
+  def testTwoThreadsNoContentionSloppy(self):
+    self._testTwoThreadsNoContention(sloppy=True)
+
+  def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
+    """Tests where all the workers race in producing elements.
+
+    Note: this is in contrast with the prevous test which carefully sequences
+    the execution of the map functions.
+
+    Args:
+      sloppy: Whether to be sloppy or not.
+    """
+    with self.test_session() as sess:
+      self._clear_coordination_events()
+      done_first_event = False
+      sess.run(
+          self.init_op,
+          feed_dict={
+              self.input_values: [4, 5, 6],
+              self.cycle_length: 2,
+              self.block_length: 1,
+              self.sloppy: sloppy,
+          })
+      for i, expected_element in enumerate(
+          self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+                           1)):
+        if done_first_event:  # First event starts the worker threads.
+          self._allow_all_map_threads()
+          self.read_coordination_events[expected_element].acquire()
+        else:
+          self.write_coordination_events[expected_element].set()
+        time.sleep(0.1)  # Sleep to consistently "avoid" the race condition.
+        actual_element = sess.run(self.next_element)
+        if not done_first_event:
+          done_first_event = True
+          self.assertTrue(
+              self.read_coordination_events[expected_element].acquire(False))
+        self.assertEqual(expected_element * expected_element, actual_element,
+                         "At index %s: %s expected, got: %s" %
+                         (i, expected_element, actual_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(self.next_element)
+
   def testTwoThreadsNoContentionWithRaces(self):
-    """Tests where all the workers race in producing elements.
-
-    Note: this is in contrast with the prevous test which carefully sequences
-    the execution of the map functions.
-    """
-    with self.test_session() as sess:
-      self._clear_coordination_events()
-      done_first_event = False
-      sess.run(
-          self.init_op,
-          feed_dict={
-              self.input_values: [4, 5, 6],
-              self.cycle_length: 2,
-              self.block_length: 1
-          })
-      for i, expected_element in enumerate(
-          self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
-                           1)):
-        if done_first_event:  # First event starts the worker threads.
-          self._allow_all_map_threads()
-          self.read_coordination_events[expected_element].acquire()
-        else:
-          self.write_coordination_events[expected_element].set()
-        time.sleep(0.1)  # Sleep to consistently "avoid" the race condition.
-        actual_element = sess.run(self.next_element)
-        if not done_first_event:
-          done_first_event = True
-          self.assertTrue(
-              self.read_coordination_events[expected_element].acquire(False))
-        self.assertEqual(expected_element * expected_element, actual_element,
-                         "At index %s: %s expected, got: %s" %
-                         (i, expected_element, actual_element))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(self.next_element)
+    self._testTwoThreadsNoContentionWithRaces()
+
+  def testTwoThreadsNoContentionWithRacesSloppy(self):
+    self._testTwoThreadsNoContentionWithRaces(sloppy=True)
 
-  def testTwoThreadsNoContentionBlockLength(self):
+  def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
    # num_threads > 1.
    # Explicit coordination should result in `Dataset.interleave()` behavior
    with self.test_session() as sess:
@@ -258,7 +283,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [4, 5, 6],
              self.cycle_length: 2,
-              self.block_length: 2
+              self.block_length: 2,
+              self.sloppy: sloppy
          })
      for i, expected_element in enumerate(
          self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@@ -276,11 +302,21 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testTwoThreadsNoContentionWithRacesAndBlocking(self):
+  def testTwoThreadsNoContentionBlockLength(self):
+    self._testTwoThreadsNoContentionBlockLength()
+
+  def testTwoThreadsNoContentionBlockLengthSloppy(self):
+    self._testTwoThreadsNoContentionBlockLength(sloppy=True)
+
+  def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
    """Tests where all the workers race in producing elements.
 
    Note: this is in contrast with the prevous test which carefully sequences
    the execution of the map functions.
+
+
+    Args:
+      sloppy: Whether to be sloppy or not.
    """
    with self.test_session() as sess:
      self._clear_coordination_events()
@@ -290,7 +326,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [4, 5, 6],
              self.cycle_length: 2,
-              self.block_length: 2
+              self.block_length: 2,
+              self.sloppy: sloppy
          })
      for i, expected_element in enumerate(
          self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@@ -312,7 +349,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testEmptyInput(self):
+  def testTwoThreadsNoContentionWithRacesAndBlocking(self):
+    self._testTwoThreadsNoContentionWithRacesAndBlocking()
+
+  def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
+    self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
+
+  def _testEmptyInput(self, sloppy=False):
    with self.test_session() as sess:
      # Empty input.
      self._clear_coordination_events()
@@ -321,12 +364,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [],
              self.cycle_length: 2,
-              self.block_length: 3
+              self.block_length: 3,
+              self.sloppy: sloppy
          })
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testNonEmptyInputIntoEmptyOutputs(self):
+  def testEmptyInput(self):
+    self._testEmptyInput()
+
+  def testEmptyInputSloppy(self):
+    self._testEmptyInput(sloppy=True)
+
+  def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
    # Non-empty input leading to empty output.
    with self.test_session() as sess:
      self._clear_coordination_events()
@@ -335,12 +385,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [0, 0, 0],
              self.cycle_length: 2,
-              self.block_length: 3
+              self.block_length: 3,
+              self.sloppy: sloppy
          })
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testPartiallyEmptyOutputs(self):
+  def testNonEmptyInputIntoEmptyOutputs(self):
+    self._testNonEmptyInputIntoEmptyOutputs()
+
+  def testNonEmptyInputIntoEmptyOutputsSloppy(self):
+    self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
+
+  def _testPartiallyEmptyOutputs(self, sloppy=False):
    # Mixture of non-empty and empty interleaved datasets.
    with self.test_session() as sess:
      self._clear_coordination_events()
@@ -350,7 +407,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [4, 0, 6],
              self.cycle_length: 2,
-              self.block_length: 1
+              self.block_length: 1,
+              self.sloppy: sloppy,
          })
      for i, expected_element in enumerate(
          self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
@@ -367,7 +425,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testDelayedOutput(self):
+  def testPartiallyEmptyOutputs(self):
+    self._testPartiallyEmptyOutputs()
+
+  def testPartiallyEmptyOutputsSloppy(self):
+    self._testPartiallyEmptyOutputs(sloppy=True)
+
+  def testDelayedOutputSloppy(self):
    # Explicitly control the sequence of events to ensure we correctly avoid
    # head-of-line blocking.
    with self.test_session() as sess:
@@ -377,7 +441,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [4, 5, 6],
              self.cycle_length: 2,
-              self.block_length: 1
+              self.block_length: 1,
+              self.sloppy: True,
          })
 
      mis_ordering = [
@@ -391,7 +456,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testBlockLengthWithContention(self):
+  def testBlockLengthWithContentionSloppy(self):
    with self.test_session() as sess:
      self._clear_coordination_events()
      done_first_event = False
@@ -400,7 +465,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [4, 5, 6],
              self.cycle_length: 2,
-              self.block_length: 3
+              self.block_length: 3,
+              self.sloppy: True
          })
      # Test against a generating sequence that differs from the uncontended
      # case, in order to prove sloppy correctness.
@@ -422,7 +488,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(self.next_element)
 
-  def testEarlyExit(self):
+  def _testEarlyExit(self, sloppy=False):
    # Exiting without consuming all input should not block
    with self.test_session() as sess:
      self._clear_coordination_events()
@@ -431,7 +497,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          feed_dict={
              self.input_values: [4, 5, 6],
              self.cycle_length: 3,
-              self.block_length: 2
+              self.block_length: 2,
+              self.sloppy: sloppy
          })
      for i in range(4, 7):
        self.write_coordination_events[i].set()
@@ -445,7 +512,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
      self.read_coordination_events[i].acquire()
      self.write_coordination_events[i].set()
 
-  def testTooManyReaders(self):
+  def testEarlyExit(self):
+    self._testEarlyExit()
+
+  def testEarlyExitSloppy(self):
+    self._testEarlyExit(sloppy=True)
+
+  def _testTooManyReaders(self, sloppy=False):
 
    def interleave_fn(x):
      dataset = dataset_ops.Dataset.from_tensors(x)
@@ -455,8 +528,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
    dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
    dataset = dataset.repeat(self.repeat_count)
    dataset = dataset.apply(
-        sloppy_ops.sloppy_interleave(interleave_fn, cycle_length=16,
-                                     block_length=2))
+        interleave_ops.parallel_interleave(
+            interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
    iterator = dataset.make_one_shot_iterator()
 
    with self.test_session() as sess:
@@ -468,6 +541,11 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
      self.assertItemsEqual(output_values, expected_values)
 
+  def testTooManyReaders(self):
+    self._testTooManyReaders()
+
+  def testTooManyReadersSloppy(self):
+    self._testTooManyReaders(sloppy=True)
+
 if __name__ == "__main__":
   test.main()
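The test refactoring above applies one convention throughout: each test body moves into a private `_testFoo(self, sloppy=False)` helper that feeds the Python bool into the `self.sloppy` placeholder created in `setUp()`, and two thin public wrappers run it in both modes. A condensed, self-contained sketch of that convention, with a hypothetical test name and a simplified map function (module paths as in this revision):

```python
import numpy as np

from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class SloppyOptionSmokeTest(test.TestCase):

  def _testElementsRegardlessOfOrder(self, sloppy=False):
    # Feed the Python bool through a placeholder, as the tests above do, so
    # a single graph serves both the deterministic and the sloppy mode.
    sloppy_ph = array_ops.placeholder(dtypes.bool, shape=[])
    dataset = dataset_ops.Dataset.from_tensor_slices(
        np.int64([4, 5, 6])).apply(
            interleave_ops.parallel_interleave(
                lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
                cycle_length=2, block_length=1, sloppy=sloppy_ph))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    with self.test_session() as sess:
      sess.run(iterator.initializer, feed_dict={sloppy_ph: sloppy})
      results = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          results.append(sess.run(next_element))
      # Sloppiness may permute the order, never the multiset of elements.
      self.assertItemsEqual(results, [4] * 4 + [5] * 5 + [6] * 6)

  def testDeterministic(self):
    self._testElementsRegardlessOfOrder()

  def testSloppy(self):
    self._testElementsRegardlessOfOrder(sloppy=True)


if __name__ == "__main__":
  test.main()
```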
@@ -60,9 +60,9 @@ py_library(
         "enumerate_ops.py",
         "error_ops.py",
         "grouping.py",
+        "interleave_ops.py",
         "resampling.py",
         "scan_ops.py",
-        "sloppy_ops.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -23,14 +23,16 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import deprecation
 
 
-class SloppyInterleaveDataset(dataset_ops.Dataset):
+class ParallelInterleaveDataset(dataset_ops.Dataset):
   """A `Dataset` that maps a function over its input and flattens the result."""
 
-  def __init__(self, input_dataset, map_func, cycle_length, block_length):
-    """See `tf.contrib.data.sloppy_interleave()` for details."""
-    super(SloppyInterleaveDataset, self).__init__()
+  def __init__(self, input_dataset, map_func, cycle_length, block_length,
+               sloppy):
+    """See `tf.contrib.data.parallel_interleave()` for details."""
+    super(ParallelInterleaveDataset, self).__init__()
     self._input_dataset = input_dataset
 
     @function.Defun(*nest.flatten(input_dataset.output_types))
@@ -62,13 +64,16 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
         cycle_length, dtype=dtypes.int64, name="cycle_length")
     self._block_length = ops.convert_to_tensor(
         block_length, dtype=dtypes.int64, name="block_length")
+    self._sloppy = ops.convert_to_tensor(
+        sloppy, dtype=dtypes.bool, name="sloppy")
 
   def _as_variant_tensor(self):
-    return gen_dataset_ops.sloppy_interleave_dataset(
+    return gen_dataset_ops.parallel_interleave_dataset(
         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
         self._map_func.captured_inputs,
         self._cycle_length,
         self._block_length,
+        self._sloppy,
         f=self._map_func,
         output_types=nest.flatten(self.output_types),
         output_shapes=nest.flatten(self.output_shapes))
@@ -82,6 +87,53 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
     return self._output_types
 
 
+def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
+  """A parallel version of the `Dataset.interleave()` transformation.
+
+  `parallel_interleave()` maps `map_func` across its input to produce nested
+  datasets, and outputs their elements interleaved. Unlike
+  @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
+  datasets in parallel, which increases the throughput, especially in the
+  presence of stragglers. Furthermore, the `sloppy` argument can be used to
+  improve performance, by relaxing the requirement that the outputs are produced
+  in a deterministic order, and allowing the implementation to skip over nested
+  datasets whose elements are not readily available when requested.
+
+  Example usage:
+
+  ```python
+  # Preprocess 4 files concurrently.
+  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
+  dataset = filenames.apply(
+      tf.contrib.data.parallel_interleave(
+          lambda filename: tf.data.TFRecordDataset(filename),
+          cycle_length=4))
+  ```
+
+  WARNING: If `sloppy` is `True`, the order of produced elements is not
+  deterministic.
+
+  Args:
+    map_func: A function mapping a nested structure of tensors to a `Dataset`.
+    cycle_length: The number of threads to interleave from in parallel.
+    block_length: The number of consecutive elements to pull from a thread
+      before advancing to the next thread.
+    sloppy: If false, elements are produced in deterministic order. Otherwise,
+      the implementation is allowed, for the sake of expediency, to produce
+      elements in a non-deterministic order.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}.
+  """
+  def _apply_fn(dataset):
+    return ParallelInterleaveDataset(
+        dataset, map_func, cycle_length, block_length, sloppy)
+  return _apply_fn
+
+
+@deprecation.deprecated(
+    None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
 def sloppy_interleave(map_func, cycle_length, block_length=1):
   """A non-deterministic version of the `Dataset.interleave()` transformation.
 
@@ -132,6 +184,6 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
     @{tf.data.Dataset.apply}.
   """
   def _apply_fn(dataset):
-    return SloppyInterleaveDataset(
-        dataset, map_func, cycle_length, block_length)
+    return ParallelInterleaveDataset(
+        dataset, map_func, cycle_length, block_length, sloppy=True)
   return _apply_fn
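Both wrappers follow the contrib transformation idiom: `parallel_interleave(...)` does not transform a dataset itself, it returns `_apply_fn`, a function from `Dataset` to `Dataset` that `Dataset.apply()` later invokes. A small sketch of that composition, with illustrative values (module paths as in this revision):

```python
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops

# parallel_interleave() only captures its arguments and returns _apply_fn.
transform = interleave_ops.parallel_interleave(
    lambda x: dataset_ops.Dataset.from_tensors(x).repeat(3),
    cycle_length=2, block_length=1, sloppy=False)

# Dataset.apply() hands the upstream dataset to the returned function...
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).apply(transform)

# ...which is equivalent to invoking the transformation by hand.
same_dataset = transform(dataset_ops.Dataset.from_tensor_slices([1, 2, 3]))
```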
@@ -5924,8 +5924,8 @@ tf_kernel_library(
 )
 
 tf_kernel_library(
-    name = "sloppy_interleave_dataset_op",
-    srcs = ["sloppy_interleave_dataset_op.cc"],
+    name = "parallel_interleave_dataset_op",
+    srcs = ["parallel_interleave_dataset_op.cc"],
     deps = [
         ":captured_function",
         ":dataset",
@@ -6162,6 +6162,7 @@ tf_kernel_library(
         ":map_and_batch_dataset_op",
         ":map_dataset_op",
         ":padded_batch_dataset_op",
+        ":parallel_interleave_dataset_op",
         ":parallel_map_dataset_op",
         ":prefetch_dataset_op",
         ":range_dataset_op",
@@ -6170,7 +6171,6 @@ tf_kernel_library(
         ":scan_dataset_op",
         ":shuffle_dataset_op",
         ":skip_dataset_op",
-        ":sloppy_interleave_dataset_op",
         ":sparse_tensor_slice_dataset_op",
         ":sql_dataset_ops",
         ":take_dataset_op",
@@ -336,7 +336,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     const DataTypeVector output_types_;
     const std::vector<PartialTensorShape> output_shapes_;
     const std::unique_ptr<CapturedFunction> captured_func_;
     const Eigen::ThreadPoolDevice* device_;  // not owned
   };
 
   const int graph_def_version_;
@@ -17,12 +17,11 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/captured_function.h"
 #include "tensorflow/core/kernels/dataset_utils.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
 
-#include "tensorflow/core/kernels/captured_function.h"
-
 namespace tensorflow {
 
 namespace {
@@ -30,9 +29,9 @@ namespace {
 // See documentation in ../ops/dataset_ops.cc for a high-level
 // description of the following op.
 
-class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
+class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
  public:
-  explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx)
+  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
       : UnaryDatasetOpKernel(ctx),
         graph_def_version_(ctx->graph_def_version()) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
@@ -62,13 +61,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
     OP_REQUIRES(ctx, block_length > 0,
                 errors::InvalidArgument("`block_length` must be > 0"));
 
+    bool sloppy;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
+
     std::unique_ptr<CapturedFunction> captured_func;
     OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
                                                  std::move(other_arguments),
                                                  &captured_func));
 
     *output = new Dataset(input, std::move(captured_func), cycle_length,
-                          block_length, output_types_, output_shapes_);
+                          block_length, sloppy, output_types_, output_shapes_);
   }
 
  private:
@@ -76,12 +78,13 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
   public:
    Dataset(const DatasetBase* input,
            std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
-           int64 block_length, const DataTypeVector& output_types,
+           int64 block_length, bool sloppy, const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : input_(input),
          captured_func_(std::move(captured_func)),
          cycle_length_(cycle_length),
          block_length_(block_length),
+          sloppy_(sloppy),
          output_types_(output_types),
          output_shapes_(output_shapes) {
      input_->Ref();
@@ -91,8 +94,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
 
    std::unique_ptr<IteratorBase> MakeIterator(
        const string& prefix) const override {
-      return std::unique_ptr<IteratorBase>(
-          new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")}));
+      return std::unique_ptr<IteratorBase>(new Iterator(
+          {this, strings::StrCat(prefix, "::ParallelInterleave")}));
    }
 
    const DataTypeVector& output_dtypes() const override {
@@ -103,7 +106,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
    }
 
    string DebugString() override {
-      return "SloppyInterleaveDatasetOp::Dataset";
+      return "ParallelInterleaveDatasetOp::Dataset";
    }
 
   private:
@@ -131,16 +134,24 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
                            bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
-        // Search for available items, blocking if necessary.
+        const int64 num_workers = worker_threads_.size();
+        if (num_workers == 0) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
        while (!cancelled_) {
-          for (size_t i = 0; i < dataset()->cycle_length_; ++i) {
-            size_t index = (next_index_ + i) % dataset()->cycle_length_;
+          // Wait for an item to become available, blocking if necessary. If we
+          // are allowed to be sloppy, we can skip over input datasets that do
+          // not have an item readily available.
+          const int64 n = dataset()->sloppy_ ? num_workers : 1LL;
+          for (int64 i = 0; i < n; ++i) {
+            int64 index = (next_index_ + i) % num_workers;
            if (output_elements_[index].is_produced) {
              next_index_ = index;
              if (i == 0) {
                block_count_++;
                if (block_count_ == dataset()->block_length_) {
-                  next_index_ = (index + 1) % dataset()->cycle_length_;
+                  next_index_ = (index + 1) % num_workers;
                  block_count_ = 0;
                }
              } else {
@@ -150,7 +161,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
              if (output_elements_[index].end_of_sequence) {
                output_elements_[index].is_produced = false;
                output_elements_[index].cond_var.notify_one();
-                next_index_ = (index + 1) % dataset()->cycle_length_;
+                next_index_ = (index + 1) % num_workers;
                block_count_ = 0;
                i = -1;  // Restart the inner loop
                continue;
@@ -174,11 +185,21 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
            *end_of_sequence = true;
            return Status::OK();
          }
 
+          // If we are not allowed to be sloppy and
+          // `worker_threads_[next_index]` has finished, advance `next_index`.
+          if (!dataset()->sloppy_ && worker_threads_[next_index_].finished) {
+            next_index_ = (next_index_ + 1) % num_workers;
+            continue;
+          }
+
          // No values available; wait until woken up.
+          // TODO(jsimsa): Use slot-specific condition variable for
+          // coordination of elements consumption.
          cond_var_.wait(l);
        }
        return errors::Cancelled(
-            "SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext");
+            "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
      }
 
     private:
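The selection loop above is the heart of the change: in deterministic mode the iterator may only wait on the single slot at `next_index_`, while in sloppy mode it may probe every worker slot and take whichever has an element ready. A pure-Python model of just that selection rule (illustrative sketch, not TF API):

```python
def pick_slot(produced, next_index, sloppy):
  """Models the kernel's slot search; `produced` has one flag per worker.

  Returns the chosen slot index, or None where the kernel would block on
  cond_var_ and retry (head-of-line blocking in deterministic mode).
  """
  num_workers = len(produced)
  n = num_workers if sloppy else 1  # sloppy mode may inspect every slot
  for i in range(n):
    index = (next_index + i) % num_workers
    if produced[index]:
      return index
  return None


# Slot 0 is a straggler, but slot 1 already has an element ready:
assert pick_slot([False, True], next_index=0, sloppy=True) == 1
assert pick_slot([False, True], next_index=0, sloppy=False) is None  # blocks
```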
@@ -201,6 +222,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
        condition_variable cond_var;
      };
 
+      struct ThreadStatus {
+        // The underlying thread uses `finished` to communicate to the producer
+        // that it has finished.
+        bool finished = false;
+        // The underlying thread object.
+        std::unique_ptr<Thread> thread;
+
+        explicit ThreadStatus(Thread* thread) : thread(thread) {}
+      };
+
      Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (worker_threads_.empty()) {
@@ -220,11 +251,10 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
          std::unique_ptr<IteratorBase> itr;
          TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
              ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
-          worker_threads_.emplace_back(
-              std::unique_ptr<Thread>(ctx->env()->StartThread(
-                  {}, "worker_thread",
-                  std::bind(&Iterator::WorkerThread, this,
-                            new IteratorContext(*ctx), i, itr.release()))));
+          worker_threads_.emplace_back(ctx->env()->StartThread(
+              {}, "worker_thread",
+              std::bind(&Iterator::WorkerThread, this,
+                        new IteratorContext(*ctx), i, itr.release())));
          num_active_threads_ = i + 1;
        }
      }
@@ -264,6 +294,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
        std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
        auto cleanup = gtl::MakeCleanup([this, thread_index] {
          mutex_lock l(mu_);
+          worker_threads_[thread_index].finished = true;
          num_active_threads_--;
          cond_var_.notify_all();
        });
@@ -345,13 +376,14 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
      // Pointers to the worker threads. This must be last to ensure the
      // threads have exited before any other members are deallocated.
      // TODO(b/65178177): Avoid allocating additional threads.
-      std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
+      std::vector<ThreadStatus> worker_threads_ GUARDED_BY(mu_);
    };
 
    const DatasetBase* const input_;
    const std::unique_ptr<CapturedFunction> captured_func_;
    const int64 cycle_length_;
    const int64 block_length_;
+    const bool sloppy_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
  };
@@ -362,8 +394,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
  NameAttrList func_;
};
 
-REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU),
-                        SloppyInterleaveDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
+                        ParallelInterleaveDatasetOp);
 
 }  // namespace
 
@@ -59,7 +59,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
   Dataset(const DatasetBase* input, int64 buffer_size,
           IteratorContext::Params ctx_params)
       : input_(input),
-
         buffer_size_(buffer_size),
         ctx_params_(std::move(ctx_params)) {
     input_->Ref();
@@ -32629,95 +32629,6 @@ op {
     }
   }
 }
-op {
-  name: "SloppyInterleaveDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
-op {
-  name: "SloppyInterleaveDataset"
-  input_arg {
-    name: "input_dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "other_arguments"
-    type_list_attr: "Targuments"
-  }
-  input_arg {
-    name: "cycle_length"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "block_length"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "f"
-    type: "func"
-  }
-  attr {
-    name: "Targuments"
-    type: "list(type)"
-    has_minimum: true
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-}
 op {
   name: "Softmax"
   input_arg {
@@ -285,11 +285,12 @@ f: A function mapping elements of `input_dataset`, concatenated with
 `output_types` and `output_shapes`.
 )doc");
 
-REGISTER_OP("SloppyInterleaveDataset")
+REGISTER_OP("ParallelInterleaveDataset")
     .Input("input_dataset: variant")
     .Input("other_arguments: Targuments")
     .Input("cycle_length: int64")
    .Input("block_length: int64")
+    .Input("sloppy: bool")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")