[tf.data] Replacing default argument values for the tf.data.Dataset.interleave API with None, while preserving the original behavior.

PiperOrigin-RevId: 305809806
Change-Id: I7471c4d787014d335b393346a356d9ebcdf845ff

parent 51ae56daae
commit c003ecc1d6
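
For orientation, a minimal sketch (illustration only, not the TensorFlow source) of how the new None defaults are mapped back to the original values; it mirrors the dataset_ops.py hunk further down, with AUTOTUNE standing in for tf.data.experimental.AUTOTUNE.

# Hypothetical helper, for illustration; the real resolution happens inside
# Dataset.interleave(), as shown in the dataset_ops.py diff below.
AUTOTUNE = -1  # sentinel value behind tf.data.experimental.AUTOTUNE


def resolve_interleave_defaults(cycle_length=None, block_length=None):
  # block_length=None now behaves exactly like the old default of 1.
  if block_length is None:
    block_length = 1
  # cycle_length=None now behaves exactly like the old default of AUTOTUNE.
  if cycle_length is None:
    cycle_length = AUTOTUNE
  return cycle_length, block_length


assert resolve_interleave_defaults() == (AUTOTUNE, 1)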
@@ -76,6 +76,8 @@ namespace data {
    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;

namespace {

constexpr char kTfDataParallelInterleaveWorkerPool[] =
    "tf_data_parallel_interleave_worker_pool";
constexpr char kParallelism[] = "parallelism";
@@ -113,7 +115,10 @@ constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;
// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;

namespace {
inline int64 CeilDiv(int64 numerator, int64 denominator) {
  return (numerator + denominator - 1) / denominator;
}

int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
                                  int64 block_length) {
  if (configured_buffer_output_elements != model::kAutotune) {
@@ -140,6 +145,7 @@ int64 OpVersionFromOpName(absl::string_view op_name) {
    return 4;
  }
}

} // namespace

// The motivation for creating an alternative implementation of parallel
@@ -1522,14 +1528,6 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase* input,
                                              DatasetBase** output) {
-  int64 cycle_length = 0;
-  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
-  if (cycle_length == model::kAutotune) {
-    cycle_length = port::NumSchedulableCPUs();
-  }
-  OP_REQUIRES(ctx, cycle_length > 0,
-              errors::InvalidArgument("`cycle_length` must be > 0"));
-
  int64 block_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
  OP_REQUIRES(ctx, block_length > 0,
@@ -1561,6 +1559,24 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
  OP_REQUIRES(
      ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
      errors::InvalidArgument("num_parallel_calls must be greater than zero."));
+  int64 cycle_length = 0;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
+  if (cycle_length == model::kAutotune) {
+    if (num_parallel_calls != model::kAutotune) {
+      cycle_length = std::min(num_parallel_calls,
+                              static_cast<int64>(port::MaxParallelism()));
+    } else {
+      // If parallelism is to be autotuned, we set the cycle length so that
+      // the number of threads created for the current and future cycle elements
+      // roughly matches the number of schedulable cores.
+      const int num_threads_per_cycle_length = kDefaultCyclePrefetchFactor + 1;
+      cycle_length =
+          CeilDiv(port::MaxParallelism(), num_threads_per_cycle_length);
+    }
+  }
+  OP_REQUIRES(ctx, cycle_length > 0,
+              errors::InvalidArgument("`cycle_length` must be > 0"));
+
  OP_REQUIRES(
      ctx, num_parallel_calls <= cycle_length,
      errors::InvalidArgument(
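
The cycle-length derivation added above is small enough to check by hand. Below is a hedged Python sketch of the same rule, assuming kDefaultCyclePrefetchFactor is 2 (which is what the `(cpu_count + 2) // 3` expression in the test diff below implies) and treating port::MaxParallelism() as the schedulable-core count; it is an illustration, not the kernel code.

# Mirrors the logic added to ParallelInterleaveDatasetOp::MakeDataset above.
AUTOTUNE = -1                      # stands in for model::kAutotune
DEFAULT_CYCLE_PREFETCH_FACTOR = 2  # assumed value of kDefaultCyclePrefetchFactor


def derive_cycle_length(cycle_length, num_parallel_calls, max_parallelism):
  if cycle_length == AUTOTUNE:
    if num_parallel_calls != AUTOTUNE:
      # Bounded by the requested parallelism and the machine size.
      cycle_length = min(num_parallel_calls, max_parallelism)
    else:
      # One current plus kDefaultCyclePrefetchFactor future elements per cycle
      # slot, so the thread count roughly matches the schedulable cores.
      threads_per_slot = DEFAULT_CYCLE_PREFETCH_FACTOR + 1
      cycle_length = -(-max_parallelism // threads_per_slot)  # ceiling division
  return cycle_length


# For example, on a 12-core machine:
assert derive_cycle_length(AUTOTUNE, AUTOTUNE, 12) == 4  # CeilDiv(12, 3)
assert derive_cycle_length(AUTOTUNE, 8, 12) == 8         # min(8, 12)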
@@ -35,13 +35,14 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test


-def _interleave(lists, cycle_length, block_length):
+def _interleave(lists, cycle_length, block_length, num_parallel_calls=None):
  """Reference implementation of interleave used for testing.

  Args:
    lists: a list of lists to interleave
    cycle_length: the length of the interleave cycle
    block_length: the length of the interleave block
+    num_parallel_calls: the number of parallel calls

  Yields:
    Elements of `lists` interleaved in the order determined by `cycle_length`
@@ -55,8 +56,15 @@ def _interleave(lists, cycle_length, block_length):
  # `open_iterators` are the iterators whose elements are currently being
  # interleaved.
  open_iterators = []
-  if cycle_length == dataset_ops.AUTOTUNE:
-    cycle_length = multiprocessing.cpu_count()
+  if cycle_length is None:
+    # The logic here needs to match the interleave C++ kernels.
+    if num_parallel_calls is None:
+      cycle_length = multiprocessing.cpu_count()
+    elif num_parallel_calls == dataset_ops.AUTOTUNE:
+      cycle_length = (multiprocessing.cpu_count() + 2) // 3
+    else:
+      cycle_length = min(num_parallel_calls, multiprocessing.cpu_count())
+
  for i in range(cycle_length):
    if all_iterators:
      open_iterators.append(all_iterators.pop(0))
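
As a concrete, hand-checkable illustration of the interleave order this reference implementation describes (not part of the change; it assumes the standard tf.data interleave semantics): with cycle_length=2 and block_length=1, elements are drawn one at a time from two open lists until they are exhausted, and only then is the third list opened.

# Hypothetical check against the _interleave helper defined above.
expected = [1, 4, 2, 5, 3, 6, 7, 8, 9]
assert list(_interleave([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                        cycle_length=2, block_length=1)) == expected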
@@ -162,7 +170,7 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
              num_parallel_calls=[None, 1, 3, 5, 7]) +
          combinations.combine(
              input_values=[np.int64([4, 5, 6, 7])],
-              cycle_length=dataset_ops.AUTOTUNE,
+              cycle_length=None,
              block_length=3,
              num_parallel_calls=[None, 1]) + combinations.combine(
                  input_values=[np.int64([]), np.int64([0, 0, 0])],
@@ -182,7 +190,8 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
                                        cycle_length, block_length, num_parallel_calls)
    expected_output = [
        element for element in _interleave(
-            _repeat(input_values, count), cycle_length, block_length)
+            _repeat(input_values, count), cycle_length, block_length,
+            num_parallel_calls)
    ]
    self.assertDatasetProduces(dataset, expected_output)

@@ -259,7 +268,7 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
              block_length=2,
              num_parallel_calls=[1, 3, 5, 7]) + combinations.combine(
                  input_values=[np.int64([4, 5, 6, 7])],
-                  cycle_length=dataset_ops.AUTOTUNE,
+                  cycle_length=None,
                  block_length=3,
                  num_parallel_calls=1) + combinations.combine(
                      input_values=[np.int64([4, 0, 6])],
@@ -278,7 +287,8 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
    dataset = dataset.with_options(options)
    expected_output = [
        element for element in _interleave(
-            _repeat(input_values, count), cycle_length, block_length)
+            _repeat(input_values, count), cycle_length, block_length,
+            num_parallel_calls)
    ]
    get_next = self.getNext(dataset)
    actual_output = []
@@ -1675,8 +1675,8 @@ name=None))

  def interleave(self,
                 map_func,
-                cycle_length=AUTOTUNE,
-                block_length=1,
+                cycle_length=None,
+                block_length=None,
                 num_parallel_calls=None,
                 deterministic=None):
    """Maps `map_func` across this dataset, and interleaves the results.
@@ -1745,12 +1745,13 @@ name=None))
    Args:
      map_func: A function mapping a dataset element to a dataset.
      cycle_length: (Optional.) The number of input elements that will be
-        processed concurrently. If not specified, the value will be derived from
-        the number of available CPU cores. If the `num_parallel_calls` argument
-        is set to `tf.data.experimental.AUTOTUNE`, the `cycle_length` argument
-        also identifies the maximum degree of parallelism.
+        processed concurrently. If not set, the tf.data runtime decides what it
+        should be based on available CPU. If `num_parallel_calls` is set to
+        `tf.data.experimental.AUTOTUNE`, the `cycle_length` argument identifies
+        the maximum degree of parallelism.
      block_length: (Optional.) The number of consecutive elements to produce
-        from each input element before cycling to another input element.
+        from each input element before cycling to another input element. If not
+        set, defaults to 1.
      num_parallel_calls: (Optional.) If specified, the implementation creates a
        threadpool, which is used to fetch inputs from cycle elements
        asynchronously and in parallel. The default behavior is to fetch inputs
@@ -1767,6 +1768,12 @@ name=None))
    Returns:
      Dataset: A `Dataset`.
    """
+    if block_length is None:
+      block_length = 1
+
+    if cycle_length is None:
+      cycle_length = AUTOTUNE
+
    if num_parallel_calls is None:
      return InterleaveDataset(self, map_func, cycle_length, block_length)
    else:
@@ -4203,6 +4210,7 @@ class InterleaveDataset(UnaryDataset):

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""

    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
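
To make the user-facing effect concrete, here is a short usage sketch (illustrative only; it assumes TensorFlow at this revision). The three calls below are equivalent, since cycle_length=None and block_length=None resolve to the previous defaults:

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def make_child(x):
  # Each input element expands into a small three-element dataset.
  return tf.data.Dataset.from_tensors(x).repeat(3)


base = tf.data.Dataset.range(4)

# New defaults: cycle_length and block_length may simply be omitted.
a = base.interleave(make_child)
# Equivalent explicit form after this change...
b = base.interleave(make_child, cycle_length=None, block_length=None)
# ...and the pre-change explicit defaults.
c = base.interleave(make_child, cycle_length=AUTOTUNE, block_length=1)

assert (list(a.as_numpy_iterator()) == list(b.as_numpy_iterator())
        == list(c.as_numpy_iterator()))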
@@ -58,7 +58,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"
@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"
@@ -59,7 +59,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"
@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"
@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"
@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"
@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"