Generalizing sloppy_interleave, making sloppiness an option.

PiperOrigin-RevId: 173687797
Jiri Simsa 2017-10-27 10:29:36 -07:00 committed by TensorFlower Gardener
parent 7775a66043
commit 6b05b36cd2
11 changed files with 283 additions and 210 deletions
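At a glance: this commit replaces the non-deterministic-only `sloppy_interleave` transformation with `parallel_interleave`, which takes an explicit `sloppy` argument. A minimal before/after sketch, based on the docstring example and the deprecation message in this diff; note that this commit's `__init__.py` still re-exports only `sloppy_interleave`, so the `tf.contrib.data.parallel_interleave` path assumes the follow-up export:

```python
import tensorflow as tf

filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")

# Before this commit: interleaving in parallel always implied sloppiness.
dataset = filenames.apply(
    tf.contrib.data.sloppy_interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=4))

# After this commit: sloppiness is an option; sloppy=False keeps the
# deterministic `Dataset.interleave()` output order.
dataset = filenames.apply(
    tf.contrib.data.parallel_interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=4,
        sloppy=True))
```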


@ -50,6 +50,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element
from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
from tensorflow.contrib.data.python.ops.grouping import group_by_window
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
@ -57,7 +58,6 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset
from tensorflow.contrib.data.python.ops.readers import TextLineDataset
from tensorflow.contrib.data.python.ops.readers import TFRecordDataset
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave
from tensorflow.python.data.ops.iterator_ops import Iterator
# pylint: enable=unused-import


@ -143,6 +143,29 @@ py_test(
],
)
py_test(
name = "interleave_dataset_op_test",
size = "small",
srcs = ["interleave_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual", # b/67958761
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],
)
py_test(
name = "iterator_ops_cluster_test",
size = "small",
@ -352,29 +375,6 @@ py_test(
],
)
py_test(
name = "sloppy_transformation_dataset_op_test",
size = "small",
srcs = ["sloppy_transformation_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual", # b/67958761
],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],
)
py_test(
name = "sql_dataset_op_test",
size = "small",


@ -25,7 +25,7 @@ import time
from six.moves import zip_longest
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import sloppy_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@ -34,12 +34,13 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
class SloppyInterleaveDatasetTest(test.TestCase):
class ParallelInterleaveDatasetTest(test.TestCase):
def setUp(self):
self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
self.repeat_count = 2
@ -69,9 +70,9 @@ class SloppyInterleaveDatasetTest(test.TestCase):
self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
.repeat(self.repeat_count).apply(
sloppy_ops.sloppy_interleave(
interleave_ops.parallel_interleave(
interleave_fn, self.cycle_length,
self.block_length)))
self.block_length, self.sloppy)))
self.iterator = self.dataset.make_initializable_iterator()
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
@ -161,7 +162,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
for i in range(4, 7):
self.write_coordination_events[i].set()
def testSingleThreaded(self):
def _testSingleThreaded(self, sloppy=False):
# cycle_length=1, block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
with self.test_session() as sess:
@ -171,7 +172,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 1,
self.block_length: 1
self.block_length: 1,
self.sloppy: sloppy
})
for expected_element in self._interleave(
@ -182,7 +184,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testTwoThreadsNoContention(self):
def testSingleThreaded(self):
self._testSingleThreaded()
def testSingleThreadedSloppy(self):
self._testSingleThreaded(sloppy=True)
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
with self.test_session() as sess:
@ -193,7 +201,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1
self.block_length: 1,
self.sloppy: sloppy
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -211,43 +220,59 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testTwoThreadsNoContention(self):
self._testTwoThreadsNoContention()
def testTwoThreadsNoContentionSloppy(self):
self._testTwoThreadsNoContention(sloppy=True)
def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
"""Tests where all the workers race in producing elements.
Note: this is in contrast with the previous test which carefully sequences
the execution of the map functions.
Args:
sloppy: Whether to be sloppy or not.
"""
with self.test_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
self.init_op,
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: sloppy,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
1)):
if done_first_event: # First event starts the worker threads.
self._allow_all_map_threads()
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
self.read_coordination_events[expected_element].acquire(False))
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testTwoThreadsNoContentionWithRaces(self):
"""Tests where all the workers race in producing elements.
self._testTwoThreadsNoContentionWithRaces()
Note: this is in contrast with the previous test which carefully sequences
the execution of the map functions.
"""
with self.test_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
self.init_op,
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
1)):
if done_first_event: # First event starts the worker threads.
self._allow_all_map_threads()
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
self.read_coordination_events[expected_element].acquire(False))
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testTwoThreadsNoContentionWithRacesSloppy(self):
self._testTwoThreadsNoContentionWithRaces(sloppy=True)
def testTwoThreadsNoContentionBlockLength(self):
def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
with self.test_session() as sess:
@ -258,7 +283,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 2
self.block_length: 2,
self.sloppy: sloppy
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -276,11 +302,21 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
def testTwoThreadsNoContentionBlockLength(self):
self._testTwoThreadsNoContentionBlockLength()
def testTwoThreadsNoContentionBlockLengthSloppy(self):
self._testTwoThreadsNoContentionBlockLength(sloppy=True)
def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
"""Tests where all the workers race in producing elements.
Note: this is in contrast with the previous test which carefully sequences
the execution of the map functions.
Args:
sloppy: Whether to be sloppy or not.
"""
with self.test_session() as sess:
self._clear_coordination_events()
@ -290,7 +326,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 2
self.block_length: 2,
self.sloppy: sloppy
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -312,7 +349,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testEmptyInput(self):
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking()
def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
def _testEmptyInput(self, sloppy=False):
with self.test_session() as sess:
# Empty input.
self._clear_coordination_events()
@ -321,12 +364,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [],
self.cycle_length: 2,
self.block_length: 3
self.block_length: 3,
self.sloppy: sloppy
})
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testNonEmptyInputIntoEmptyOutputs(self):
def testEmptyInput(self):
self._testEmptyInput()
def testEmptyInputSloppy(self):
self._testEmptyInput(sloppy=True)
def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
# Non-empty input leading to empty output.
with self.test_session() as sess:
self._clear_coordination_events()
@ -335,12 +385,19 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [0, 0, 0],
self.cycle_length: 2,
self.block_length: 3
self.block_length: 3,
self.sloppy: sloppy
})
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testPartiallyEmptyOutputs(self):
def testNonEmptyInputIntoEmptyOutputs(self):
self._testNonEmptyInputIntoEmptyOutputs()
def testNonEmptyInputIntoEmptyOutputsSloppy(self):
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
def _testPartiallyEmptyOutputs(self, sloppy=False):
# Mixture of non-empty and empty interleaved datasets.
with self.test_session() as sess:
self._clear_coordination_events()
@ -350,7 +407,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 0, 6],
self.cycle_length: 2,
self.block_length: 1
self.block_length: 1,
self.sloppy: sloppy,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
@ -367,7 +425,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testDelayedOutput(self):
def testPartiallyEmptyOutputs(self):
self._testPartiallyEmptyOutputs()
def testPartiallyEmptyOutputsSloppy(self):
self._testPartiallyEmptyOutputs(sloppy=True)
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
with self.test_session() as sess:
@ -377,7 +441,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1
self.block_length: 1,
self.sloppy: True,
})
mis_ordering = [
@ -391,7 +456,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testBlockLengthWithContention(self):
def testBlockLengthWithContentionSloppy(self):
with self.test_session() as sess:
self._clear_coordination_events()
done_first_event = False
@ -400,7 +465,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 3
self.block_length: 3,
self.sloppy: True
})
# Test against a generating sequence that differs from the uncontended
# case, in order to prove sloppy correctness.
@ -422,7 +488,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testEarlyExit(self):
def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
with self.test_session() as sess:
self._clear_coordination_events()
@ -431,7 +497,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 3,
self.block_length: 2
self.block_length: 2,
self.sloppy: sloppy
})
for i in range(4, 7):
self.write_coordination_events[i].set()
@ -445,7 +512,13 @@ class SloppyInterleaveDatasetTest(test.TestCase):
self.read_coordination_events[i].acquire()
self.write_coordination_events[i].set()
def testTooManyReaders(self):
def testEarlyExit(self):
self._testEarlyExit()
def testEarlyExitSloppy(self):
self._testEarlyExit(sloppy=True)
def _testTooManyReaders(self, sloppy=False):
def interleave_fn(x):
dataset = dataset_ops.Dataset.from_tensors(x)
@ -455,8 +528,8 @@ class SloppyInterleaveDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
dataset = dataset.repeat(self.repeat_count)
dataset = dataset.apply(
sloppy_ops.sloppy_interleave(interleave_fn, cycle_length=16,
block_length=2))
interleave_ops.parallel_interleave(
interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
iterator = dataset.make_one_shot_iterator()
with self.test_session() as sess:
@ -468,6 +541,11 @@ class SloppyInterleaveDatasetTest(test.TestCase):
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
self.assertItemsEqual(output_values, expected_values)
def testTooManyReaders(self):
self._testTooManyReaders()
def testTooManyReadersSloppy(self):
self._testTooManyReaders(sloppy=True)
if __name__ == "__main__":
test.main()
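The refactoring pattern applied throughout this test file: each test body becomes a private helper parameterized on `sloppy`, and two thin wrappers exercise the deterministic and sloppy modes against the same graph. A condensed sketch with a hypothetical `Foo` test name:

```python
def _testFoo(self, sloppy=False):
  # Shared body; the Python bool is fed into the `self.sloppy`
  # placeholder so one graph serves both modes.
  ...

def testFoo(self):
  self._testFoo()

def testFooSloppy(self):
  self._testFoo(sloppy=True)
```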


@ -60,9 +60,9 @@ py_library(
"enumerate_ops.py",
"error_ops.py",
"grouping.py",
"interleave_ops.py",
"resampling.py",
"scan_ops.py",
"sloppy_ops.py",
],
srcs_version = "PY2AND3",
deps = [


@ -23,14 +23,16 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
class SloppyInterleaveDataset(dataset_ops.Dataset):
class ParallelInterleaveDataset(dataset_ops.Dataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func, cycle_length, block_length):
"""See `tf.contrib.data.sloppy_interleave()` for details."""
super(SloppyInterleaveDataset, self).__init__()
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy):
"""See `tf.contrib.data.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__()
self._input_dataset = input_dataset
@function.Defun(*nest.flatten(input_dataset.output_types))
@ -62,13 +64,16 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
def _as_variant_tensor(self):
return gen_dataset_ops.sloppy_interleave_dataset(
return gen_dataset_ops.parallel_interleave_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._map_func.captured_inputs,
self._cycle_length,
self._block_length,
self._sloppy,
f=self._map_func,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
@ -82,6 +87,53 @@ class SloppyInterleaveDataset(dataset_ops.Dataset):
return self._output_types
def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
"""A parallel version of the `Dataset.interleave()` transformation.
`parallel_interleave()` maps `map_func` across its input to produce nested
datasets, and outputs their elements interleaved. Unlike
@{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
datasets in parallel, which increases the throughput, especially in the
presence of stragglers. Furthermore, the `sloppy` argument can be used to
improve performance, by relaxing the requirement that the outputs are produced
in a deterministic order, and allowing the implementation to skip over nested
datasets whose elements are not readily available when requested.
Example usage:
```python
# Preprocess 4 files concurrently.
filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
dataset = filenames.apply(
tf.contrib.data.parallel_interleave(
lambda filename: tf.data.TFRecordDataset(filename),
cycle_length=4))
```
WARNING: If `sloppy` is `True`, the order of produced elements is not
deterministic.
Args:
map_func: A function mapping a nested structure of tensors to a `Dataset`.
cycle_length: The number of threads to interleave from in parallel.
block_length: The number of consecutive elements to pull from a thread
before advancing to the next thread.
sloppy: If false, elements are produced in deterministic order. Otherwise,
the implementation is allowed, for the sake of expediency, to produce
elements in a non-deterministic order.
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return ParallelInterleaveDataset(
dataset, map_func, cycle_length, block_length, sloppy)
return _apply_fn
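Because `sloppy` is converted with `ops.convert_to_tensor` and passed to the kernel as a graph input (see the `REGISTER_OP` change at the end of this diff), it need not be a Python constant; the tests above feed it through a placeholder so a single graph can run in either mode. A minimal sketch of that pattern, assuming TF 1.x graph mode and the `tf.contrib.data.parallel_interleave` export:

```python
import tensorflow as tf

sloppy = tf.placeholder(tf.bool, shape=[])
dataset = tf.data.Dataset.from_tensor_slices([4, 5, 6]).apply(
    tf.contrib.data.parallel_interleave(
        lambda x: tf.data.Dataset.from_tensors(x).repeat(x),
        cycle_length=2,
        block_length=1,
        sloppy=sloppy))
iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
  # Flip the feed to False to recover deterministic interleaving.
  sess.run(iterator.initializer, feed_dict={sloppy: True})
```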
@deprecation.deprecated(
None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
def sloppy_interleave(map_func, cycle_length, block_length=1):
"""A non-deterministic version of the `Dataset.interleave()` transformation.
@ -132,6 +184,6 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return SloppyInterleaveDataset(
dataset, map_func, cycle_length, block_length)
return ParallelInterleaveDataset(
dataset, map_func, cycle_length, block_length, sloppy=True)
return _apply_fn


@ -5924,8 +5924,8 @@ tf_kernel_library(
)
tf_kernel_library(
name = "sloppy_interleave_dataset_op",
srcs = ["sloppy_interleave_dataset_op.cc"],
name = "parallel_interleave_dataset_op",
srcs = ["parallel_interleave_dataset_op.cc"],
deps = [
":captured_function",
":dataset",
@ -6162,6 +6162,7 @@ tf_kernel_library(
":map_and_batch_dataset_op",
":map_dataset_op",
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
":prefetch_dataset_op",
":range_dataset_op",
@ -6170,7 +6171,6 @@ tf_kernel_library(
":scan_dataset_op",
":shuffle_dataset_op",
":skip_dataset_op",
":sloppy_interleave_dataset_op",
":sparse_tensor_slice_dataset_op",
":sql_dataset_ops",
":take_dataset_op",


@ -336,7 +336,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned
const Eigen::ThreadPoolDevice* device_; // not owned
};
const int graph_def_version_;


@ -17,12 +17,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/captured_function.h"
#include "tensorflow/core/kernels/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/kernels/captured_function.h"
namespace tensorflow {
namespace {
@ -30,9 +29,9 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx)
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
@ -62,13 +61,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0"));
bool sloppy;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
std::move(other_arguments),
&captured_func));
*output = new Dataset(input, std::move(captured_func), cycle_length,
block_length, output_types_, output_shapes_);
block_length, sloppy, output_types_, output_shapes_);
}
private:
@ -76,12 +78,13 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
Dataset(const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, const DataTypeVector& output_types,
int64 block_length, bool sloppy, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: input_(input),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
sloppy_(sloppy),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
@ -91,8 +94,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIterator(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")}));
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::ParallelInterleave")}));
}
const DataTypeVector& output_dtypes() const override {
@ -103,7 +106,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() override {
return "SloppyInterleaveDatasetOp::Dataset";
return "ParallelInterleaveDatasetOp::Dataset";
}
private:
@ -131,16 +134,24 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
// Search for available items, blocking if necessary.
const int64 num_workers = worker_threads_.size();
if (num_workers == 0) {
*end_of_sequence = true;
return Status::OK();
}
while (!cancelled_) {
for (size_t i = 0; i < dataset()->cycle_length_; ++i) {
size_t index = (next_index_ + i) % dataset()->cycle_length_;
// Wait for an item to become available, blocking if necessary. If we
// are allowed to be sloppy, we can skip over input datasets that do
// not have an item readily available.
const int64 n = dataset()->sloppy_ ? num_workers : 1LL;
for (int64 i = 0; i < n; ++i) {
int64 index = (next_index_ + i) % num_workers;
if (output_elements_[index].is_produced) {
next_index_ = index;
if (i == 0) {
block_count_++;
if (block_count_ == dataset()->block_length_) {
next_index_ = (index + 1) % dataset()->cycle_length_;
next_index_ = (index + 1) % num_workers;
block_count_ = 0;
}
} else {
@ -150,7 +161,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (output_elements_[index].end_of_sequence) {
output_elements_[index].is_produced = false;
output_elements_[index].cond_var.notify_one();
next_index_ = (index + 1) % dataset()->cycle_length_;
next_index_ = (index + 1) % num_workers;
block_count_ = 0;
i = -1; // Restart the inner loop
continue;
@ -174,11 +185,21 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
// If we are not allowed to be sloppy and
// `worker_threads_[next_index]` has finished, advance `next_index`.
if (!dataset()->sloppy_ && worker_threads_[next_index_].finished) {
next_index_ = (next_index_ + 1) % num_workers;
continue;
}
// No values available; wait until woken up.
// TODO(jsimsa): Use slot-specific condition variable for
// coordination of elements consumption.
cond_var_.wait(l);
}
return errors::Cancelled(
"SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext");
"ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
}
private:
@ -201,6 +222,16 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
condition_variable cond_var;
};
struct ThreadStatus {
// The underlying thread uses `finished` to communicate to the producer
// that it has finished.
bool finished = false;
// The underlying thread object.
std::unique_ptr<Thread> thread;
explicit ThreadStatus(Thread* thread) : thread(thread) {}
};
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (worker_threads_.empty()) {
@ -220,11 +251,10 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> itr;
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
worker_threads_.emplace_back(
std::unique_ptr<Thread>(ctx->env()->StartThread(
{}, "worker_thread",
std::bind(&Iterator::WorkerThread, this,
new IteratorContext(*ctx), i, itr.release()))));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
std::bind(&Iterator::WorkerThread, this,
new IteratorContext(*ctx), i, itr.release())));
num_active_threads_ = i + 1;
}
}
@ -264,6 +294,7 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
auto cleanup = gtl::MakeCleanup([this, thread_index] {
mutex_lock l(mu_);
worker_threads_[thread_index].finished = true;
num_active_threads_--;
cond_var_.notify_all();
});
@ -345,13 +376,14 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Pointers to the worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads.
std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
std::vector<ThreadStatus> worker_threads_ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_;
const int64 block_length_;
const bool sloppy_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
@ -362,8 +394,8 @@ class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel {
NameAttrList func_;
};
REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU),
SloppyInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
} // namespace
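The heart of the new kernel is the consumer loop in `GetNextInternal` above: when `sloppy_` is false it inspects only the current cycle slot (`n = 1`, advancing only when that worker has finished, per the new `ThreadStatus` bookkeeping), whereas when `sloppy_` is true it may scan all `num_workers` slots starting at `next_index_` and take the first produced element. A self-contained Python sketch of that selection policy; the names are illustrative, not TensorFlow APIs:

```python
def pick_slot(produced, next_index, sloppy):
  """Return the index of a slot with an element ready, or None to block.

  produced: list of bools, True if that worker's output slot is filled.
  """
  n = len(produced) if sloppy else 1  # mirrors: sloppy_ ? num_workers : 1
  for i in range(n):
    index = (next_index + i) % len(produced)
    if produced[index]:
      return index
  return None  # caller waits on the condition variable

# Deterministic mode blocks on slot 0 even though slot 2 has data;
# sloppy mode skips ahead and returns slot 2.
assert pick_slot([False, False, True], 0, sloppy=False) is None
assert pick_slot([False, False, True], 0, sloppy=True) == 2
```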


@ -59,7 +59,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
Dataset(const DatasetBase* input, int64 buffer_size,
IteratorContext::Params ctx_params)
: input_(input),
buffer_size_(buffer_size),
ctx_params_(std::move(ctx_params)) {
input_->Ref();


@ -32629,95 +32629,6 @@ op {
}
}
}
op {
name: "SloppyInterleaveDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "cycle_length"
type: DT_INT64
}
input_arg {
name: "block_length"
type: DT_INT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "SloppyInterleaveDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "other_arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "cycle_length"
type: DT_INT64
}
input_arg {
name: "block_length"
type: DT_INT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "f"
type: "func"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
}
op {
name: "Softmax"
input_arg {


@ -285,11 +285,12 @@ f: A function mapping elements of `input_dataset`, concatenated with
`output_types` and `output_shapes`.
)doc");
REGISTER_OP("SloppyInterleaveDataset")
REGISTER_OP("ParallelInterleaveDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("cycle_length: int64")
.Input("block_length: int64")
.Input("sloppy: bool")
.Output("handle: variant")
.Attr("f: func")
.Attr("Targuments: list(type) >= 0")