Add prefetching into parallel_interleave

This change adds two parameters to parallel_interleave:
 - prefetch_input_elements: determines the number of input iterators to
     prefetch, allowing buffers to warm up and data to be prefetched without
     blocking the main thread (i.e. the GetNext() call).
 - buffer_output_elements: in order to avoid creating thousands of threads, we
     fuse the .prefetch() operator into parallel_interleave as an additional
     parameter. Its value plays the same role as the buffer size that would
     otherwise be passed to `.prefetch()`.
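
A usage sketch for illustration (the shard filenames and record-parsing lambda
are hypothetical, and the entry point is assumed to be the
`tf.contrib.data.parallel_interleave` export of this module):

    import tensorflow as tf

    filenames = tf.data.Dataset.from_tensor_slices(
        ["shard-%05d.tfrecord" % i for i in range(16)])  # hypothetical shards

    dataset = filenames.apply(
        tf.contrib.data.parallel_interleave(
            lambda f: tf.data.TFRecordDataset(f),
            cycle_length=4,              # input datasets read in parallel
            block_length=16,             # consecutive records pulled per dataset
            sloppy=False,                # keep deterministic ordering
            buffer_output_elements=16,   # per-iterator buffer (the fused .prefetch())
            prefetch_input_elements=2))  # iterators warmed up ahead of need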

PiperOrigin-RevId: 179726088
Commit 76db97fe39 (parent 7d94d7672a)
Brennan Saeta, 2017-12-20 13:29:07 -08:00; committed by TensorFlower Gardener
11 changed files with 678 additions and 231 deletions


@ -41,6 +41,7 @@ from tensorflow.python.platform import test
class InterleaveDatasetTest(test.TestCase):
def _interleave(self, lists, cycle_length, block_length):
# TODO(b/69678297): Consolidate python interleave implementations.
num_open = 0
# `all_iterators` acts as a queue of iterators over each element of `lists`.
@ -255,11 +256,15 @@ class InterleaveDatasetSeriazationTest(
class ParallelInterleaveDatasetTest(test.TestCase):
def setUp(self):
self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[])
self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[])
self.error = None
self.repeat_count = 2
# Set up threading events used to sequence when items are produced that
@ -276,6 +281,10 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.write_coordination_events[x].wait()
self.write_coordination_events[x].clear()
self.read_coordination_events[x].release()
if self.error:
err = self.error
self.error = None
raise err # pylint: disable=raising-bad-type
return x * x
def map_fn(x):
@ -286,11 +295,13 @@ class ParallelInterleaveDatasetTest(test.TestCase):
dataset = dataset.repeat(x)
return dataset.map(map_fn)
self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values)
.repeat(self.repeat_count).apply(
interleave_ops.parallel_interleave(
interleave_fn, self.cycle_length,
self.block_length, self.sloppy)))
self.dataset = (
dataset_ops.Dataset.from_tensor_slices(self.input_values)
.repeat(self.repeat_count).apply(
interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
self.block_length, self.sloppy,
self.buffer_output_elements,
self.prefetch_input_elements)))
self.iterator = self.dataset.make_initializable_iterator()
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
@ -380,7 +391,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
for i in range(4, 7):
self.write_coordination_events[i].set()
def _testSingleThreaded(self, sloppy=False):
def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
# cycle_length=1, block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
with self.test_session() as sess:
@ -391,7 +402,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.input_values: [4, 5, 6],
self.cycle_length: 1,
self.block_length: 1,
self.sloppy: sloppy
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: prefetch_input_elements,
})
for expected_element in self._interleave(
@ -408,6 +421,41 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testSingleThreadedSloppy(self):
self._testSingleThreaded(sloppy=True)
def testSingleThreadedPrefetch1Itr(self):
self._testSingleThreaded(prefetch_input_elements=1)
def testSingleThreadedPrefetch1ItrSloppy(self):
self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
with self.test_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
feed_dict={
self.input_values: [3, 7, 4],
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: False,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 1,
})
# Add coordination values for 3 and 7
self.read_coordination_events[3] = threading.Semaphore(0)
self.write_coordination_events[3] = threading.Event()
self.read_coordination_events[7] = threading.Semaphore(0)
self.write_coordination_events[7] = threading.Event()
for expected_element in self._interleave(
[[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
self.write_coordination_events[expected_element].set()
output = sess.run(self.next_element)
self.assertEqual(expected_element * expected_element, output)
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
@ -420,7 +468,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: sloppy
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 1,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -463,6 +513,8 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 1,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -502,7 +554,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 2,
self.sloppy: sloppy
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 1,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -545,7 +599,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 2,
self.sloppy: sloppy
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 1,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
@ -583,7 +639,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.input_values: [],
self.cycle_length: 2,
self.block_length: 3,
self.sloppy: sloppy
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 0,
})
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
@ -604,7 +662,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.input_values: [0, 0, 0],
self.cycle_length: 2,
self.block_length: 3,
self.sloppy: sloppy
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 0,
})
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
@ -615,7 +675,8 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testNonEmptyInputIntoEmptyOutputsSloppy(self):
self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
def _testPartiallyEmptyOutputs(self, sloppy=False):
def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
race_indices = {2, 8, 14}  # Sequence points when sloppy mode has race conditions
# Mixture of non-empty and empty interleaved datasets.
with self.test_session() as sess:
self._clear_coordination_events()
@ -627,27 +688,31 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: prefetch_input_elements,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads
# First event starts the worker threads. Additionally, when running the
# sloppy case with prefetch_input_elements=0, we get stuck if we wait
# for the read coordination event for certain event orderings in the
# presence of finishing iterators.
if done_first_event and not (sloppy and (i in race_indices)):
self.read_coordination_events[expected_element].acquire()
actual_element = sess.run(self.next_element)
if not done_first_event:
if not done_first_event or (sloppy and (i in race_indices)):
done_first_event = True
self.read_coordination_events[expected_element].acquire()
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testPartiallyEmptyOutputs(self):
self._testPartiallyEmptyOutputs()
def testPartiallyEmptyOutputsSloppy(self):
self._testPartiallyEmptyOutputs(sloppy=True)
self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
@ -661,6 +726,8 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: True,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 0,
})
mis_ordering = [
@ -683,8 +750,10 @@ class ParallelInterleaveDatasetTest(test.TestCase):
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 3,
self.sloppy: True
self.block_length: 1,
self.sloppy: True,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 1,
})
# Test against a generating sequence that differs from the uncontended
# case, in order to prove sloppy correctness.
@ -692,7 +761,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self._interleave(
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
cycle_length=2,
block_length=2)):
block_length=3)):
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads.
self.read_coordination_events[expected_element].acquire()
@ -716,7 +785,9 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.input_values: [4, 5, 6],
self.cycle_length: 3,
self.block_length: 2,
self.sloppy: sloppy
self.sloppy: sloppy,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 0,
})
for i in range(4, 7):
self.write_coordination_events[i].set()
@ -790,6 +861,139 @@ class ParallelInterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testErrorsInOutputFn(self):
with self.test_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: False,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 0,
})
except_on_element_indices = set([3])
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
1)):
if i in except_on_element_indices:
self.error = ValueError()
self.write_coordination_events[expected_element].set()
with self.assertRaises(errors.InvalidArgumentError):
sess.run(self.next_element)
else:
self.write_coordination_events[expected_element].set()
actual_element = sess.run(self.next_element)
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testErrorsInInputFn(self):
def map_py_fn(x):
if x == 5:
raise ValueError()
return x
def map_fn(x):
return script_ops.py_func(map_py_fn, [x], x.dtype)
def interleave_fn(x):
dataset = dataset_ops.Dataset.from_tensors(x)
dataset = dataset.repeat(x)
return dataset
self.dataset = (
dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn)
.repeat(self.repeat_count).apply(
interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
self.block_length, self.sloppy,
self.buffer_output_elements,
self.prefetch_input_elements)))
self.iterator = self.dataset.make_initializable_iterator()
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
with self.test_session() as sess:
sess.run(
self.init_op,
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: False,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 0,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
if expected_element == 5:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(self.next_element)
else:
actual_element = sess.run(self.next_element)
self.assertEqual(expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
def testErrorsInInterleaveFn(self):
def map_py_fn(x):
if x == 5:
raise ValueError()
return x
def interleave_fn(x):
dataset = dataset_ops.Dataset.from_tensors(x)
y = script_ops.py_func(map_py_fn, [x], x.dtype)
dataset = dataset.repeat(y)
return dataset
self.dataset = (
dataset_ops.Dataset.from_tensor_slices(self.input_values)
.repeat(self.repeat_count).apply(
interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
self.block_length, self.sloppy,
self.buffer_output_elements,
self.prefetch_input_elements)))
self.iterator = self.dataset.make_initializable_iterator()
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
with self.test_session() as sess:
sess.run(
self.init_op,
feed_dict={
self.input_values: [4, 5, 6],
self.cycle_length: 2,
self.block_length: 1,
self.sloppy: False,
self.buffer_output_elements: 1,
self.prefetch_input_elements: 0,
})
for i, expected_element in enumerate(
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
if expected_element == 5:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(self.next_element)
else:
actual_element = sess.run(self.next_element)
self.assertEqual(expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.next_element)
if __name__ == "__main__":
test.main()
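
The `_interleave` reference helper used throughout these tests is only
partially visible in the first hunk. As a rough model of the slot-based
round-robin semantics the assertions rely on, a minimal reconstruction might
look like the following (an illustrative sketch, not the exact helper):

    def reference_interleave(lists, cycle_length, block_length):
      # A queue of iterators over each element of `lists`.
      all_iterators = [iter(l) for l in lists]
      # Fixed slots forming the interleave cycle; empty slots hold None.
      open_iterators = [None] * cycle_length
      num_open = 0
      for i in range(cycle_length):
        if all_iterators:
          open_iterators[i] = all_iterators.pop(0)
          num_open += 1
      while num_open or all_iterators:
        for i in range(cycle_length):
          if open_iterators[i] is None:
            if not all_iterators:
              continue
            # An exhausted slot is refilled from the queue on its next turn.
            open_iterators[i] = all_iterators.pop(0)
            num_open += 1
          for _ in range(block_length):
            try:
              yield next(open_iterators[i])
            except StopIteration:
              open_iterators[i] = None
              num_open -= 1
              break

    # e.g. list(reference_interleave([[1, 1], [2, 2], [3, 3]], 2, 1))
    #      yields [1, 2, 1, 2, 3, 3].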


@ -122,6 +122,7 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//third_party/py/numpy",


@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@ -31,7 +32,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy):
sloppy, buffer_output_elements, prefetch_input_elements):
"""See `tf.contrib.data.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__()
self._input_dataset = input_dataset
@ -74,6 +75,14 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
self._buffer_output_elements = convert.optional_param_to_tensor(
"buffer_output_elements",
buffer_output_elements,
argument_default=2 * block_length)
self._prefetch_input_elements = convert.optional_param_to_tensor(
"prefetch_input_elements",
prefetch_input_elements,
argument_default=2 * cycle_length)
def _as_variant_tensor(self):
return gen_dataset_ops.parallel_interleave_dataset(
@ -82,6 +91,8 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
self._cycle_length,
self._block_length,
self._sloppy,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func,
output_types=nest.flatten(
sparse.as_dense_types(self.output_types, self.output_classes)),
@ -101,7 +112,12 @@ class ParallelInterleaveDataset(dataset_ops.Dataset):
return self._output_types
def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
def parallel_interleave(map_func,
cycle_length,
block_length=1,
sloppy=False,
buffer_output_elements=None,
prefetch_input_elements=None):
"""A parallel version of the `Dataset.interleave()` transformation.
`parallel_interleave()` maps `map_func` across its input to produce nested
@ -129,12 +145,17 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
Args:
map_func: A function mapping a nested structure of tensors to a `Dataset`.
cycle_length: The number of threads to interleave from in parallel.
block_length: The number of consecutive elements to pull from a thread
before advancing to the next thread.
cycle_length: The number of input `Dataset`s to interleave from in parallel.
block_length: The number of consecutive elements to pull from an input
`Dataset` before advancing to the next input `Dataset`.
sloppy: If false, elements are produced in deterministic order. Otherwise,
the implementation is allowed, for the sake of expediency, to produce
elements in a non-deterministic order.
buffer_output_elements: The number of elements each iterator being
interleaved should buffer (similar to the `.prefetch()` transformation for
each interleaved iterator).
prefetch_input_elements: The number of input elements to transform to
iterators before they are needed for interleaving.
Returns:
A `Dataset` transformation function, which can be passed to
@ -142,7 +163,9 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False):
"""
def _apply_fn(dataset):
return ParallelInterleaveDataset(
dataset, map_func, cycle_length, block_length, sloppy)
dataset, map_func, cycle_length, block_length, sloppy,
buffer_output_elements, prefetch_input_elements)
return _apply_fn
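
For reference, a sketch of what the optional arguments resolve to when
omitted, per the `argument_default`s in the constructor above (`make_records`
is a hypothetical placeholder):

    # These two applications are equivalent under the defaults:
    dataset.apply(interleave_ops.parallel_interleave(
        make_records, cycle_length=4, block_length=16))
    dataset.apply(interleave_ops.parallel_interleave(
        make_records, cycle_length=4, block_length=16, sloppy=False,
        buffer_output_elements=32,    # defaults to 2 * block_length
        prefetch_input_elements=8))   # defaults to 2 * cycle_length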
@ -187,11 +210,11 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
`Dataset`.
cycle_length: The number of threads to interleave from in parallel.
block_length: The number of consecutive elements to pull from a thread
before advancing to the next thread. Note: sloppy_interleave will
skip the remainder of elements in the block_length in order to avoid
blocking.
cycle_length: The number of input `Dataset`s to interleave from in parallel.
block_length: The number of consecutive elements to pull from an input
`Dataset` before advancing to the next input `Dataset`. Note:
`sloppy_interleave` will skip the remainder of elements in the
`block_length` in order to avoid blocking.
Returns:
A `Dataset` transformation function, which can be passed to
@ -199,5 +222,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
"""
def _apply_fn(dataset):
return ParallelInterleaveDataset(
dataset, map_func, cycle_length, block_length, sloppy=True)
dataset,
map_func,
cycle_length,
block_length,
sloppy=True,
buffer_output_elements=None,
prefetch_input_elements=None)
return _apply_fn
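
As the rewritten `_apply_fn` shows, `sloppy_interleave` is now a thin wrapper
over the same dataset; a sketch of the equivalence (`make_records` again a
hypothetical placeholder):

    # sloppy_interleave(f, cycle_length, block_length) now builds the same
    # op as parallel_interleave with sloppy=True and default buffering:
    dataset.apply(interleave_ops.sloppy_interleave(make_records, cycle_length=4))
    dataset.apply(interleave_ops.parallel_interleave(
        make_records, cycle_length=4, sloppy=True))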


@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <deque>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@ -48,28 +50,44 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
other_arguments.push_back(t);
}
int64 cycle_length;
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
OP_REQUIRES(ctx, cycle_length > 0,
errors::InvalidArgument("`cycle_length` must be > 0"));
int64 block_length;
int64 block_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "block_length", &block_length));
OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0"));
bool sloppy;
bool sloppy = false;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
int64 buffer_output_elements = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
&buffer_output_elements));
OP_REQUIRES(
ctx, buffer_output_elements > 0,
errors::InvalidArgument("`buffer_output_elements` must be > 0"));
int64 prefetch_input_elements = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
&prefetch_input_elements));
OP_REQUIRES(
ctx, prefetch_input_elements >= 0,
errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_,
std::move(other_arguments),
&captured_func));
*output = new Dataset(input, std::move(captured_func), cycle_length,
block_length, sloppy, output_types_, output_shapes_);
*output =
new Dataset(input, std::move(captured_func), cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements,
output_types_, output_shapes_);
}
private:
@ -77,13 +95,16 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
Dataset(const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, bool sloppy, const DataTypeVector& output_types,
int64 block_length, bool sloppy, int64 buffer_output_elements,
int64 prefetch_input_elements, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: input_(input),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
sloppy_(sloppy),
buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
@ -109,244 +130,317 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
private:
int64 num_threads() const {
return cycle_length_ + prefetch_input_elements_;
}
// Parallel interleave's implementation is designed around a few principles:
// 1. Thread creation is relatively expensive. (Not reusing
// threads causes a number of indirect costs such as poorer tcmalloc
// performance due to thread-local caches, etc.) We allocate a fixed
// number of threads at the start and never change. This is why we've
// fused functionality that is theoretically orthogonal (i.e.
// .prefetch()) into the implementation.
// 2. Drop-in replacement for standard interleave. The goal is to
// auto-opt people into an optimized implementation without any work
// on the customer's part. We thus take great pains to maintain
// identical iteration orders and full determinism (which can be
// disabled only via the sloppy flag).
// 3. Performance across a variety of environments and I/O envelopes.
//
// The actual implementation centers around a collection of worker threads
// and their corresponding worker state (tracked in the `workers_` vector).
// Worker threads repeatedly receive a vector of Tensors that are used as
// input to the flat-map function (`captured_func_`). The output of this
// function must be a dataset. The worker thread then repeatedly calls
// `GetNext()`, maintaining a buffer of elements to minimize the likelihood
// that a caller will block waiting for an element to be produced.
//
// Pointers to these worker states are kept in 2 disjoint data structures:
// 1. `interleave_` is a vector containing pointers to `WorkerState`s that
// we
// are interleaving. Worker threads backing these WorkerStates should
// be regularly producing values.
// 2. `staging_` is a deque containing pointers to WorkerStates that we
// will move to `interleave_` when an iterator in `interleave_` is
// exhausted.
//
// The client calls `GetNext[Internal]()` to retrieve an output element. The
// internal implementation updates the state of `interleave_` and `staging_`
// as output iterators (run by the worker threads) are exhausted.
//
// `input_impl_` is the input iterator that generates arguments for the
// flat-map function (`captured_func_`). It is set to an iterator at
// Iterator construction, and is fixed until we consume all input elements.
// Once it is exhausted, we reset the unique_ptr to eagerly deallocate
// memory.
//
// A few invariants are maintained:
// 1. No element in interleave_ should be a nullptr unless `staging_` is
// empty and `input_impl_` is empty.
// 2. Every `worker_` element is pointed to by at most one element of the
// union of `interleave_` and `staging_`.
// 3. Unless `input_impl_` is empty, every `worker_` must be pointed to by
// an element in `interleave_` or `staging_`.
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
output_elements_(params.dataset->cycle_length_) {}
workers_(dataset()->num_threads()) {}
~Iterator() override {
mutex_lock l(mu_);
cancelled_ = true;
// Notify all workers in case they are blocked.
for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
output_elements_[i].cond_var.notify_all();
for (auto& worker : workers_) {
worker.cond_var.notify_all();
}
}
// It is implemented so that it matches the deterministic interleave
// unless we would block waiting for an element, at which point it skips
// along to the next available value.
// unless getting the next element would block and we are allowed to be
// sloppy.
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
const int64 num_workers = worker_threads_.size();
if (num_workers == 0) {
*end_of_sequence = true;
return Status::OK();
}
while (!cancelled_) {
// Wait for an item to become available, blocking if necessary. If we
// are allowed to be sloppy, we can skip over input datasets that do
// not have an item readily available.
const int64 n = dataset()->sloppy_ ? num_workers : 1LL;
for (int64 i = 0; i < n; ++i) {
int64 index = (next_index_ + i) % num_workers;
if (output_elements_[index].is_produced) {
bool can_produce_elements = false;
bool must_wait_for_input = true;
for (int64 i = 0; i < interleave_.size(); ++i) {
int64 index = (next_index_ + i) % interleave_.size();
WorkerState* current_worker = interleave_[index];
if (!current_worker) continue; // Empty interleave elements.
can_produce_elements |= current_worker->MayHaveElements();
if (!current_worker->outputs.empty()) {
// We have an element!
next_index_ = index;
if (i == 0) {
block_count_++;
if (block_count_ == dataset()->block_length_) {
next_index_ = (index + 1) % num_workers;
next_index_ = (index + 1) % interleave_.size();
block_count_ = 0;
}
} else {
block_count_ = 0;
}
// If we encounter an EoF, advance to the next iterator
if (output_elements_[index].end_of_sequence) {
output_elements_[index].is_produced = false;
output_elements_[index].cond_var.notify_one();
next_index_ = (index + 1) % num_workers;
block_count_ = 0;
i = -1; // Restart the inner loop
continue;
}
*end_of_sequence = false;
if (output_elements_[index].output_status.ok()) {
output_elements_[index].output_value.swap(*out_tensors);
Status s = current_worker->outputs.front().status;
current_worker->outputs.front().output.swap(*out_tensors);
current_worker->outputs.pop_front();
current_worker->cond_var.notify_one();
return s;
} else if (current_worker->is_producing && !dataset()->sloppy_) {
// current_worker.outputs.empty(), and we must wait for this
// iterator.
if (next_index_ != index) {
// We have advanced to a new iterator; reset block counts.
next_index_ = index;
block_count_ = 0;
}
break;
} else if (!current_worker->is_producing) {
// This iterator has reached end of input.
interleave_[index] = nullptr;
if (input_impl_) {
// Start prefetching a new iterator.
std::vector<Tensor> args;
bool end_of_input = false;
Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
if (end_of_input) {
input_impl_.reset();
} else {
current_worker->SetInputs(s, std::move(args));
staging_.emplace_back(current_worker);
}
}
if (!staging_.empty()) {
// Move a worker from `staging_` to `interleave_`.
interleave_[index] = staging_.front();
staging_.pop_front();
next_index_ = (index + 1) % interleave_.size();
block_count_ = 0;
// Restart the inner [for] loop
can_produce_elements = true;
must_wait_for_input = false;
break;
}
output_elements_[index].is_produced = false;
output_elements_[index].cond_var.notify_one();
return output_elements_[index].output_status;
}
}
if (num_active_threads_ == 0) {
if (!can_produce_elements && !input_impl_) {
// No potential for future values.
//
// Note: this condition check must occur after checking the output
// buffer, as its possible for there to be values in the output
// buffer, even if the number of live threads is zero.
*end_of_sequence = true;
return Status::OK();
}
// If we are not allowed to be sloppy and
// `worker_threads_[next_index]` has finished, advance `next_index`.
if (!dataset()->sloppy_ && worker_threads_[next_index_].finished) {
next_index_ = (next_index_ + 1) % num_workers;
continue;
if (must_wait_for_input) {
// Wait for elements to become available.
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
interleave_[next_index_]->cond_var.wait(l);
}
}
// No values available; wait until woken up.
// TODO(jsimsa): Use slot-specific condition variable for
// coordination of elements consumption.
cond_var_.wait(l);
}
return errors::Cancelled(
"ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
}
private:
// Internal structure to manage thread coordination. All values are
// guarded by the enclosing Iterator's mu_.
struct OutputBufferElement {
// The producer must set `is_produced` to `true` after
// `output_status` or `output_value` has been written.
bool is_produced = false;
// The producer sets `output_status` if either getting the input element
// or applying the function to it fails.
Status output_status;
// Reached end of sequence for the underlying iterator.
bool end_of_sequence = false;
// The output data element.
std::vector<Tensor> output_value;
// The producer thread waits on this condition variable after having
// produced an element. The reader thread notifies this condition
// variable after reading the value.
condition_variable cond_var;
// OutputElem contains the information from a call to GetNext by an output
// iterator.
struct OutputElem {
// The output iterator sets `status` if getting the output element
// fails.
Status status;
// The buffered data element.
std::vector<Tensor> output;
explicit OutputElem(const Status& s) : status(s) {}
};
struct ThreadStatus {
// The underlying thread uses `finished` to communicate to the producer
// that it has finished.
bool finished = false;
// The underlying thread object.
std::unique_ptr<Thread> thread;
// Worker threads operate on their relevant WorkerState structs.
//
// WorkerState's fields are all protected by mu_.
struct WorkerState {
// The arguments to be used to construct an output iterator.
std::vector<Tensor> input;
// The buffered output elements.
std::deque<OutputElem> outputs;
// Set to true iff the worker thread expects to append more elements to
// outputs. is_producing can be false despite !outputs.empty().
// Concretely, all output elements will have been consumed only when:
// is_producing == false && outputs.empty();
bool is_producing = false;
// Condition variable used to coordinate between threads. The worker
// thread waits on this condition variable when it is either (1) waiting
// for the main thread to add arguments to `input`, or (2) waiting for
// the main thread to consume an element of `outputs`. The main thread
// waits on cond_var if it is waiting for the worker thread to produce
// an element into `outputs` (this implies sloppy_==false).
condition_variable cond_var;
explicit ThreadStatus(Thread* thread) : thread(thread) {}
inline bool MayHaveElements() const {
return is_producing || !outputs.empty();
}
// Sets inputs for a worker thread and notifies it to start processing.
void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
if (s.ok()) {
DCHECK(!MayHaveElements())
<< "Tried to start inputs, despite already producing!";
input = std::move(input_arguments);
is_producing = true;
cond_var.notify_one();
} else {
outputs.emplace_back(s);
}
}
};
Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (worker_threads_.empty()) {
for (int64 i = 0; i < dataset()->cycle_length_; ++i) {
// Serialize the creation of the workers and their corresponding
// input elements to ensure we match the standard interleave when
// the underlying iterators induce no delay.
worker_threads_.reserve(dataset()->num_threads());
for (int64 i = 0; i < dataset()->num_threads(); ++i) {
std::vector<Tensor> args;
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, &args, &end_of_input_));
if (end_of_input_) {
LOG(WARNING) << "Input iterator exhausted after " << i
<< " elements; cannot start all "
<< dataset()->cycle_length_ << " worker threads.";
bool end_of_input = false;
Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
if (end_of_input) {
input_impl_.reset();
return Status::OK();
}
std::unique_ptr<IteratorBase> itr;
TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr));
workers_[i].SetInputs(s, std::move(args));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
std::bind(&Iterator::WorkerThread, this,
new IteratorContext(*ctx), i, itr.release())));
num_active_threads_ = i + 1;
new IteratorContext(*ctx), i)));
if (i < dataset()->cycle_length_) {
interleave_.push_back(&workers_[i]);
} else {
staging_.push_back(&workers_[i]);
}
}
DCHECK(interleave_.size() == dataset()->cycle_length_);
DCHECK(staging_.size() == dataset()->prefetch_input_elements_);
}
return Status::OK();
}
void BlockAndUpdateOutputBuffer(mutex_lock* l, const int64 thread_index,
const Status& status,
bool end_of_sequence,
std::vector<Tensor>* out_tensors)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// We have produced an element; push it into the output buffer
// when space is available.
while (!cancelled_ && output_elements_[thread_index].is_produced) {
output_elements_[thread_index].cond_var.wait(*l);
}
if (cancelled_) {
return;
}
output_elements_[thread_index].is_produced = true;
output_elements_[thread_index].output_status = status;
output_elements_[thread_index].end_of_sequence = end_of_sequence;
if (status.ok()) {
output_elements_[thread_index].output_value.swap(*out_tensors);
} else {
output_elements_[thread_index].output_value.clear();
}
cond_var_.notify_one();
}
// Races to produce elements into the output queue buffers.
void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index,
IteratorBase* out_iterator_ptr) {
// Produces elements into the worker's output buffers.
void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
std::unique_ptr<IteratorContext> ctx(ctx_ptr);
std::unique_ptr<IteratorBase> out_iterator(out_iterator_ptr);
auto cleanup = gtl::MakeCleanup([this, thread_index] {
mutex_lock l(mu_);
worker_threads_[thread_index].finished = true;
num_active_threads_--;
cond_var_.notify_all();
workers_[thread_index].cond_var.notify_all();
});
while (true) {
// Attempt to produce an element.
bool end_of_out_itr_input = false;
std::vector<Tensor> out_tensors;
Status element_status = out_iterator->GetNext(ctx.get(), &out_tensors,
&end_of_out_itr_input);
// Handle output.
// 1. Wait for input.
std::vector<Tensor> input;
{
mutex_lock l(mu_);
BlockAndUpdateOutputBuffer(&l, thread_index, element_status,
end_of_out_itr_input, &out_tensors);
if (end_of_out_itr_input) {
// We have exhausted our current iterator; get a new iterator;
// loop to handle errors.
while (!cancelled_) {
if (end_of_input_) {
// No more iterator inputs; we're done!
return;
}
std::vector<Tensor> args;
// BlockAndUpdateOutputBuffer() sequences calls to
// input_impl_->GetNext when the out_iterator doesn't cause
// slopping.
Status input_status =
input_impl_->GetNext(ctx.get(), &args, &end_of_input_);
if (end_of_input_) {
// No more elements to produce, stop the worker thread.
return;
}
if (input_status.ok()) {
input_status = dataset::MakeIteratorFromInputElement(
ctx.get(), args, thread_index,
dataset()->captured_func_.get(), prefix(), &out_iterator);
}
if (input_status.ok()) {
// Successfully have a new out_iterator; restart the outer
// loop to produce an element.
break;
}
// We encountered an error; push the error to the output buffer.
BlockAndUpdateOutputBuffer(&l, thread_index, input_status,
/* end_of_sequence = */ false,
&out_tensors);
}
while (!cancelled_ && !workers_[thread_index].is_producing) {
workers_[thread_index].cond_var.wait(l);
}
if (cancelled_) return;
input.swap(workers_[thread_index].input);
}
// Check if we should exit.
if (cancelled_) {
return;
// 2. Run the user defined function to produce a new iterator.
std::unique_ptr<IteratorBase> iterator;
Status s = dataset::MakeIteratorFromInputElement(
ctx.get(), input, thread_index, dataset()->captured_func_.get(),
prefix(), &iterator);
input.clear(); // Release memory as early as possible.
if (!s.ok()) {
mutex_lock l(mu_);
workers_[thread_index].outputs.emplace_back(s);
workers_[thread_index].is_producing = false;
workers_[thread_index].cond_var.notify_one();
} else {
// 3. Produce elements
bool end_of_sequence = false;
while (!end_of_sequence) {
// 3.a Produce an element!
std::vector<Tensor> output_elem;
s = iterator->GetNext(ctx.get(), &output_elem, &end_of_sequence);
// 3.b Make it available to the client.
{
mutex_lock l(mu_);
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
workers_[thread_index].cond_var.wait(l);
}
if (cancelled_) return;
// Output the element.
workers_[thread_index].is_producing = !end_of_sequence;
if (!end_of_sequence) {
workers_[thread_index].outputs.emplace_back(s);
workers_[thread_index].outputs.back().output.swap(
output_elem);
}
if (dataset()->sloppy_) {
sloppy_cond_var_.notify_one();
} else {
workers_[thread_index].cond_var.notify_one();
}
}
}
}
}
@ -355,27 +449,34 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Mutex & condition variable to guard mutable iterator internals and
// coordinate among worker threads and client thread[s].
mutex mu_;
condition_variable cond_var_;
// The main thread waits on this condition variable if running in sloppy
// mode and no values are available.
condition_variable sloppy_cond_var_;
// The iterator producing elements which are converted to datasets by
// the dataset()->captured_func_ then interleaved together.
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
// Whether the input_impl_ can produce future elements.
bool end_of_input_ GUARDED_BY(mu_) = false;
// The buffer of elements to be produced. Each worker thread operates
// on a single OutputBufferElement.
std::vector<OutputBufferElement> output_elements_ GUARDED_BY(mu_);
// input_impl_ is reset when we have exhausted its input.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
// The WorkerState structs the worker threads operate on.
// workers_ elements are in at most one of interleave_ and staging_.
std::vector<WorkerState> workers_ GUARDED_BY(mu_);
// The iterators to interleave
std::vector<WorkerState*> interleave_ GUARDED_BY(mu_);
// Prefetched iterators
std::deque<WorkerState*> staging_ GUARDED_BY(mu_);
// The index into `interleave_` for the next element to produce.
size_t next_index_ GUARDED_BY(mu_) = 0;
// The number of items produced so far within the block
size_t block_count_ GUARDED_BY(mu_) = 0;
// Number of active threads.
size_t num_active_threads_ GUARDED_BY(mu_) = 0;
// Flag to instruct the worker threads to exit.
bool cancelled_ GUARDED_BY(mu_) = false;
// Pointers to the worker threads. This must be last to ensure the
// The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads.
std::vector<ThreadStatus> worker_threads_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
@ -383,6 +484,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const int64 cycle_length_;
const int64 block_length_;
const bool sloppy_;
const int64 buffer_output_elements_;
const int64 prefetch_input_elements_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
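
To make the worker/staging scheme described in the comment block above
concrete, here is a deliberately simplified pure-Python model of the same
control flow: a fixed pool of cycle_length + prefetch_input_elements workers,
each filling a bounded per-worker buffer, with the consumer rotating through
interleave slots and promoting staged workers as slots exhaust. This is an
illustrative sketch under simplifying assumptions (deterministic mode only;
no sloppy path, error propagation, or cancellation), not the kernel itself:

    import queue
    import threading

    class Worker(object):
      """Models WorkerState: one thread filling a bounded output buffer."""

      def __init__(self, make_iterator, args, buffer_output_elements):
        self.outputs = queue.Queue(maxsize=buffer_output_elements)
        self.done = threading.Event()

        def run():
          for elem in make_iterator(args):
            self.outputs.put(elem)  # blocks while the buffer is full
          self.done.set()

        self.thread = threading.Thread(target=run)
        self.thread.daemon = True
        self.thread.start()

      def get(self):
        """Returns the next buffered element, or None once exhausted."""
        # Assumes None is never a legitimate element.
        while True:
          try:
            return self.outputs.get(timeout=0.01)
          except queue.Empty:
            if self.done.is_set():
              return None

    def parallel_interleave(inputs, make_iterator, cycle_length, block_length,
                            buffer_output_elements, prefetch_input_elements):
      inputs = iter(inputs)

      def start_worker():
        try:
          return Worker(make_iterator, next(inputs), buffer_output_elements)
        except StopIteration:
          return None  # Input exhausted; nothing left to stage.

      # `interleave` holds the cycle_length slots we draw from; `staging`
      # holds prefetched workers waiting to replace an exhausted slot.
      interleave = [start_worker() for _ in range(cycle_length)]
      staging = [w for w in (start_worker()
                             for _ in range(prefetch_input_elements))
                 if w is not None]
      index = 0
      while any(w is not None for w in interleave) or staging:
        slot = index % cycle_length
        index += 1
        worker = interleave[slot]
        if worker is None:
          continue
        for _ in range(block_length):
          elem = worker.get()
          if elem is None:
            # Slot exhausted: stage a replacement input, then promote the
            # oldest staged worker into this slot.
            replacement = start_worker()
            if replacement is not None:
              staging.append(replacement)
            interleave[slot] = staging.pop(0) if staging else None
            break
          yield elem

    # e.g. deterministic round-robin over three ranges:
    # list(parallel_interleave([3, 1, 2], lambda n: iter(range(n)),
    #                          cycle_length=2, block_length=1,
    #                          buffer_output_elements=1,
    #                          prefetch_input_elements=1))
    # yields [0, 0, 1, 2, 0, 1].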


@ -27387,6 +27387,14 @@ op {
name: "sloppy"
type: DT_BOOL
}
input_arg {
name: "buffer_output_elements"
type: DT_INT64
}
input_arg {
name: "prefetch_input_elements"
type: DT_INT64
}
output_arg {
name: "handle"
type: DT_VARIANT


@ -313,6 +313,8 @@ REGISTER_OP("ParallelInterleaveDataset")
.Input("cycle_length: int64")
.Input("block_length: int64")
.Input("sloppy: bool")
.Input("buffer_output_elements: int64")
.Input("prefetch_input_elements: int64")
.Output("handle: variant")
.Attr("f: func")
.Attr("Targuments: list(type) >= 0")


@ -34,11 +34,11 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":dataset_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/util:convert",
],
)


@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.framework import constant_op
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -29,18 +29,6 @@ from tensorflow.python.ops import gen_dataset_ops
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 # 256 KB
def _convert_optional_param_to_tensor(argument_name,
argument_value,
argument_default=0,
argument_dtype=dtypes.int64):
if argument_value is not None:
return ops.convert_to_tensor(
argument_value, dtype=argument_dtype, name=argument_name)
else:
return constant_op.constant(
argument_default, dtype=argument_dtype, name=argument_name)
class TextLineDataset(Dataset):
"""A `Dataset` comprising lines from one or more text files."""
@ -58,12 +46,12 @@ class TextLineDataset(Dataset):
super(TextLineDataset, self).__init__()
self._filenames = ops.convert_to_tensor(
filenames, dtype=dtypes.string, name="filenames")
self._compression_type = _convert_optional_param_to_tensor(
self._compression_type = convert.optional_param_to_tensor(
"compression_type",
compression_type,
argument_default="",
argument_dtype=dtypes.string)
self._buffer_size = _convert_optional_param_to_tensor(
self._buffer_size = convert.optional_param_to_tensor(
"buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
def _as_variant_tensor(self):
@ -100,12 +88,12 @@ class TFRecordDataset(Dataset):
# Force the type to string even if filenames is an empty list.
self._filenames = ops.convert_to_tensor(
filenames, dtypes.string, name="filenames")
self._compression_type = _convert_optional_param_to_tensor(
self._compression_type = convert.optional_param_to_tensor(
"compression_type",
compression_type,
argument_default="",
argument_dtype=dtypes.string)
self._buffer_size = _convert_optional_param_to_tensor(
self._buffer_size = convert.optional_param_to_tensor(
"buffer_size",
buffer_size,
argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
@ -155,11 +143,11 @@ class FixedLengthRecordDataset(Dataset):
self._record_bytes = ops.convert_to_tensor(
record_bytes, dtype=dtypes.int64, name="record_bytes")
self._header_bytes = _convert_optional_param_to_tensor(
self._header_bytes = convert.optional_param_to_tensor(
"header_bytes", header_bytes)
self._footer_bytes = _convert_optional_param_to_tensor(
self._footer_bytes = convert.optional_param_to_tensor(
"footer_bytes", footer_bytes)
self._buffer_size = _convert_optional_param_to_tensor(
self._buffer_size = convert.optional_param_to_tensor(
"buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
def _as_variant_tensor(self):


@ -62,6 +62,30 @@ py_test(
],
)
py_library(
name = "convert",
srcs = ["convert.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
],
)
py_test(
name = "convert_test",
size = "small",
srcs = ["convert_test.py"],
srcs_version = "PY2AND3",
deps = [
":convert",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:util",
],
)
filegroup(
name = "all_files",
srcs = glob(


@ -0,0 +1,34 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers constructing Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
def optional_param_to_tensor(argument_name,
argument_value,
argument_default=0,
argument_dtype=dtypes.int64):
if argument_value is not None:
return ops.convert_to_tensor(
argument_value, dtype=argument_dtype, name=argument_name)
else:
return constant_op.constant(
argument_default, dtype=argument_dtype, name=argument_name)


@ -0,0 +1,53 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with user input."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class ConvertTest(test.TestCase):
def testInteger(self):
resp = convert.optional_param_to_tensor("foo", 3)
with self.test_session() as sess:
self.assertEqual(3, sess.run(resp))
def testIntegerDefault(self):
resp = convert.optional_param_to_tensor("foo", None)
with self.test_session() as sess:
self.assertEqual(0, sess.run(resp))
def testStringDefault(self):
resp = convert.optional_param_to_tensor("bar", None, "default",
dtypes.string)
with self.test_session() as sess:
self.assertEqual(compat.as_bytes("default"), sess.run(resp))
def testString(self):
resp = convert.optional_param_to_tensor("bar", "value", "default",
dtypes.string)
with self.test_session() as sess:
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
if __name__ == "__main__":
test.main()