[tf.data] Replacing default argument values for the tf.data.Dataset.interleave API with None, while preserving original behavior.

PiperOrigin-RevId: 305809806
Change-Id: I7471c4d787014d335b393346a356d9ebcdf845ff
This commit is contained in:
Jiri Simsa 2020-04-09 19:37:21 -07:00 committed by TensorFlower Gardener
parent 51ae56daae
commit c003ecc1d6
10 changed files with 64 additions and 30 deletions

View File

@ -76,6 +76,8 @@ namespace data {
ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
namespace {
constexpr char kTfDataParallelInterleaveWorkerPool[] =
"tf_data_parallel_interleave_worker_pool";
constexpr char kParallelism[] = "parallelism";
@ -113,7 +115,10 @@ constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;
// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;
namespace {
// Returns ceil(numerator / denominator) for positive denominators, using the
// classic add-(denominator - 1) trick so only one integer division is needed.
inline int64 CeilDiv(int64 numerator, int64 denominator) {
  const int64 rounded_up = numerator + denominator - 1;
  return rounded_up / denominator;
}
int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
int64 block_length) {
if (configured_buffer_output_elements != model::kAutotune) {
@ -140,6 +145,7 @@ int64 OpVersionFromOpName(absl::string_view op_name) {
return 4;
}
}
} // namespace
// The motivation for creating an alternative implementation of parallel
@ -1522,14 +1528,6 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase* input,
DatasetBase** output) {
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
if (cycle_length == model::kAutotune) {
cycle_length = port::NumSchedulableCPUs();
}
OP_REQUIRES(ctx, cycle_length > 0,
errors::InvalidArgument("`cycle_length` must be > 0"));
int64 block_length = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
OP_REQUIRES(ctx, block_length > 0,
@ -1561,6 +1559,24 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
OP_REQUIRES(
ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
errors::InvalidArgument("num_parallel_calls must be greater than zero."));
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
if (cycle_length == model::kAutotune) {
if (num_parallel_calls != model::kAutotune) {
cycle_length = std::min(num_parallel_calls,
static_cast<int64>(port::MaxParallelism()));
} else {
// If parallelism is to be autotuned, we set the cycle length so that
// the number of threads created for the current and future cycle elements
// roughly matches the number of schedulable cores.
const int num_threads_per_cycle_length = kDefaultCyclePrefetchFactor + 1;
cycle_length =
CeilDiv(port::MaxParallelism(), num_threads_per_cycle_length);
}
}
OP_REQUIRES(ctx, cycle_length > 0,
errors::InvalidArgument("`cycle_length` must be > 0"));
OP_REQUIRES(
ctx, num_parallel_calls <= cycle_length,
errors::InvalidArgument(

View File

@ -35,13 +35,14 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
def _interleave(lists, cycle_length, block_length):
def _interleave(lists, cycle_length, block_length, num_parallel_calls=None):
"""Reference implementation of interleave used for testing.
Args:
lists: a list of lists to interleave
cycle_length: the length of the interleave cycle
block_length: the length of the interleave block
num_parallel_calls: the number of parallel calls
Yields:
Elements of `lists` interleaved in the order determined by `cycle_length`
@ -55,8 +56,15 @@ def _interleave(lists, cycle_length, block_length):
# `open_iterators` are the iterators whose elements are currently being
# interleaved.
open_iterators = []
if cycle_length == dataset_ops.AUTOTUNE:
cycle_length = multiprocessing.cpu_count()
if cycle_length is None:
# The logic here needs to match interleave C++ kernels.
if num_parallel_calls is None:
cycle_length = multiprocessing.cpu_count()
elif num_parallel_calls == dataset_ops.AUTOTUNE:
cycle_length = (multiprocessing.cpu_count() + 2) // 3
else:
cycle_length = min(num_parallel_calls, multiprocessing.cpu_count())
for i in range(cycle_length):
if all_iterators:
open_iterators.append(all_iterators.pop(0))
@ -162,7 +170,7 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
num_parallel_calls=[None, 1, 3, 5, 7]) +
combinations.combine(
input_values=[np.int64([4, 5, 6, 7])],
cycle_length=dataset_ops.AUTOTUNE,
cycle_length=None,
block_length=3,
num_parallel_calls=[None, 1]) + combinations.combine(
input_values=[np.int64([]), np.int64([0, 0, 0])],
@ -182,7 +190,8 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
cycle_length, block_length, num_parallel_calls)
expected_output = [
element for element in _interleave(
_repeat(input_values, count), cycle_length, block_length)
_repeat(input_values, count), cycle_length, block_length,
num_parallel_calls)
]
self.assertDatasetProduces(dataset, expected_output)
@ -259,7 +268,7 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
block_length=2,
num_parallel_calls=[1, 3, 5, 7]) + combinations.combine(
input_values=[np.int64([4, 5, 6, 7])],
cycle_length=dataset_ops.AUTOTUNE,
cycle_length=None,
block_length=3,
num_parallel_calls=1) + combinations.combine(
input_values=[np.int64([4, 0, 6])],
@ -278,7 +287,8 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
expected_output = [
element for element in _interleave(
_repeat(input_values, count), cycle_length, block_length)
_repeat(input_values, count), cycle_length, block_length,
num_parallel_calls)
]
get_next = self.getNext(dataset)
actual_output = []

View File

@ -1675,8 +1675,8 @@ name=None))
def interleave(self,
map_func,
cycle_length=AUTOTUNE,
block_length=1,
cycle_length=None,
block_length=None,
num_parallel_calls=None,
deterministic=None):
"""Maps `map_func` across this dataset, and interleaves the results.
@ -1745,12 +1745,13 @@ name=None))
Args:
map_func: A function mapping a dataset element to a dataset.
cycle_length: (Optional.) The number of input elements that will be
processed concurrently. If not specified, the value will be derived from
the number of available CPU cores. If the `num_parallel_calls` argument
is set to `tf.data.experimental.AUTOTUNE`, the `cycle_length` argument
also identifies the maximum degree of parallelism.
processed concurrently. If not set, the tf.data runtime decides what it
should be based on available CPU. If `num_parallel_calls` is set to
`tf.data.experimental.AUTOTUNE`, the `cycle_length` argument identifies
the maximum degree of parallelism.
block_length: (Optional.) The number of consecutive elements to produce
from each input element before cycling to another input element.
from each input element before cycling to another input element. If not
set, defaults to 1.
num_parallel_calls: (Optional.) If specified, the implementation creates a
threadpool, which is used to fetch inputs from cycle elements
asynchronously and in parallel. The default behavior is to fetch inputs
@ -1767,6 +1768,12 @@ name=None))
Returns:
Dataset: A `Dataset`.
"""
if block_length is None:
block_length = 1
if cycle_length is None:
cycle_length = AUTOTUNE
if num_parallel_calls is None:
return InterleaveDataset(self, map_func, cycle_length, block_length)
else:
@ -4203,6 +4210,7 @@ class InterleaveDataset(UnaryDataset):
def __init__(self, input_dataset, map_func, cycle_length, block_length):
"""See `Dataset.interleave()` for details."""
self._input_dataset = input_dataset
self._map_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)

View File

@ -58,7 +58,7 @@ tf_class {
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "list_files"

View File

@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "list_files"

View File

@ -59,7 +59,7 @@ tf_class {
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "list_files"

View File

@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "list_files"

View File

@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "list_files"

View File

@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "list_files"

View File

@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "list_files"