Add API to control parallel interleave prefetching.

PiperOrigin-RevId: 294766661
Change-Id: I8061629522d19d408cd8b7a1981836a4ee958110
This commit is contained in:
Andrew Audibert 2020-02-12 15:06:13 -08:00 committed by TensorFlower Gardener
parent a7f1d52b03
commit 0c1ca5c674
12 changed files with 406 additions and 96 deletions

View File

@ -0,0 +1,100 @@
op {
graph_op_name: "ParallelInterleaveDatasetV4"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
Dataset that produces a stream of arguments for the function `f`.
END
}
in_arg {
name: "other_arguments"
description: <<END
Additional arguments to pass to `f` beyond those produced by `input_dataset`.
Evaluated once when the dataset is instantiated.
END
}
in_arg {
name: "cycle_length"
description: <<END
Number of datasets (each created by applying `f` to the elements of
`input_dataset`) among which the `ParallelInterleaveDatasetV2` will cycle in a
round-robin fashion.
END
}
in_arg {
name: "block_length"
description: <<END
Number of elements at a time to produce from each interleaved invocation of a
dataset returned by `f`.
END
}
in_arg {
name: "buffer_output_elements"
description: <<END
The number of elements each iterator being interleaved should buffer (similar
to the `.prefetch()` transformation for each interleaved iterator).
END
}
in_arg {
name: "prefetch_input_elements"
description: <<END
Determines the number of iterators to prefetch, allowing buffers to warm up and
data to be pre-fetched without blocking the main thread.
END
}
in_arg {
name: "num_parallel_calls"
description: <<END
Determines the number of threads that should be used for fetching data from
input datasets in parallel. The Python API `tf.data.experimental.AUTOTUNE`
constant can be used to indicate that the level of parallelism should be autotuned.
END
}
attr {
name: "f"
description: <<END
A function mapping elements of `input_dataset`, concatenated with
`other_arguments`, to a Dataset variant that contains elements matching
`output_types` and `output_shapes`.
END
}
attr {
name: "deterministic"
description: <<END
A string indicating the op-level determinism to use. Deterministic controls
whether the interleave is allowed to return elements out of order if the next
element to be returned isn't available, but a later element is. Options are
"true", "false", and "default". "default" indicates that determinism should be
decided by the `experimental_deterministic` parameter of `tf.data.Options`.
END
}
attr {
name: "Targuments"
description: <<END
Types of the elements of `other_arguments`.
END
}
attr {
name: "output_types"
}
attr {
name: "output_shapes"
}
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
description: <<END
The resulting dataset is similar to the `InterleaveDataset`, except that the
dataset will fetch records from the interleaved datasets in parallel.
The `tf.data` Python API creates instances of this op from
`Dataset.interleave()` when the `num_parallel_calls` parameter of that method
is set to any value other than `None`.
By default, the output of this dataset will be deterministic, which may result
in the dataset blocking if the next data item to be returned isn't available.
In order to avoid head-of-line blocking, one can either set the `deterministic`
attribute to "false", or leave it as "default" and set the
`experimental_deterministic` parameter of `tf.data.Options` to `False`.
This can improve performance at the expense of non-determinism.
END
}

View File

@ -89,13 +89,14 @@ constexpr std::array<const char*, 28> kPassThroughOps = {
};
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 6> kFuncDatasetOps = {
constexpr std::array<const char*, 7> kFuncDatasetOps = {
"ExperimentalParallelInterleaveDataset",
"FlatMapDataset",
"InterleaveDataset",
"ParallelInterleaveDataset",
"ParallelInterleaveDatasetV2",
"ParallelInterleaveDatasetV3"
"ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV4"
};
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {

View File

@ -33,12 +33,10 @@ namespace {
constexpr char kLegacyAutotune[] = "legacy_autotune";
constexpr char kPrefetchDataset[] = "PrefetchDataset";
constexpr std::array<const char*, 5> kAsyncDatasetOps = {
"ExperimentalMapAndBatchDataset",
"ParallelMapDataset",
"ParallelInterleaveDatasetV2",
"ParallelInterleaveDatasetV3",
"MapAndBatchDataset",
constexpr std::array<const char*, 6> kAsyncDatasetOps = {
"ExperimentalMapAndBatchDataset", "ParallelMapDataset",
"ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV4", "MapAndBatchDataset",
};
} // namespace

View File

@ -32,8 +32,9 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
"ParseExampleDataset",
};
constexpr std::array<const char*, 1> kDeterministicAttrOps = {
constexpr std::array<const char*, 2> kDeterministicAttrOps = {
"ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV4",
};
} // anonymous namespace

View File

@ -59,6 +59,10 @@ namespace data {
ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
@ -91,34 +95,64 @@ constexpr char kSizeSuffix[] = ".size";
constexpr char kInputsSuffix[] = ".inputs";
constexpr char kIsReadySuffix[] = ".is_ready";
// `kCyclePrefetchFactor * cycle_length` is the number of future cycle elements
// that will be prefetched ahead of time. The purpose of prefetching future
// cycle elements is to overlap expensive initialization (e.g. opening of a
// remote file) with other computation.
constexpr double kCyclePrefetchFactor = 2.0L;
constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";
// `kPerIteratorPrefetchFactor * block_length + 1` is the number of per-iterator
// results that will be prefetched ahead of time. The `+ 1` is to match the
// behavior of the original autotune implementation.
constexpr double kPerIteratorPrefetchFactor = 2.0L;
// `kCyclePrefetchFactor * cycle_length` is the default number of future cycle
// elements that will be prefetched ahead of time. The purpose of prefetching
// future cycle elements is to overlap expensive initialization (e.g. opening of
// a remote file) with other computation.
constexpr double kDefaultCyclePrefetchFactor = 2.0L;
// `kPerIteratorPrefetchFactor * block_length + 1` is the defualt number of
// per-iterator results that will be prefetched ahead of time. The `+ 1` is to
// match the behavior of the original implementation.
constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;
// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;
namespace {
int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
int64 block_length) {
if (configured_buffer_output_elements != model::kAutotune) {
return configured_buffer_output_elements;
}
return kDefaultPerIteratorPrefetchFactor * block_length + 1;
}
int64 ComputePrefetchInputElements(int64 configured_prefetch_input_elements,
int64 cycle_length) {
if (configured_prefetch_input_elements != model::kAutotune) {
return configured_prefetch_input_elements;
}
return kDefaultCyclePrefetchFactor * cycle_length;
}
int64 OpVersionFromOpName(absl::string_view op_name) {
if (op_name == kParallelInterleaveDatasetV2) {
return 2;
} else if (op_name == kParallelInterleaveDatasetV3) {
return 3;
} else {
DCHECK_EQ(op_name, kParallelInterleaveDatasetV4);
return 4;
}
}
} // namespace
// The motivation for creating an alternative implementation of parallel
// interleave is to decouple the degree of parallelism from the cycle length.
// This makes it possible to change the degree of parallelism (e.g. through
// auto-tuning) without changing the cycle length (which would change the order
// in which elements are produced).
//
// Furthermore, this class favors modularity over extended functionality. In
// particular, it refrains from implementing configurable buffering of output
// elements and prefetching of input iterators.
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, int64 num_parallel_calls,
int64 block_length, int64 buffer_output_elements,
int64 prefetch_input_elements, int64 num_parallel_calls,
DeterminismPolicy deterministic, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes, int op_version)
: DatasetBase(DatasetContext(ctx)),
@ -126,6 +160,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
buffer_output_elements_(
ComputeBufferOutputElements(buffer_output_elements, block_length)),
prefetch_input_elements_(ComputePrefetchInputElements(
prefetch_input_elements, cycle_length)),
num_parallel_calls_(num_parallel_calls),
deterministic_(deterministic),
output_types_(output_types),
@ -179,19 +217,44 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<std::pair<size_t, Node*>> inputs;
std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
int input_index = 0;
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
Node* num_parallel_calls_node;
TF_RETURN_IF_ERROR(
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
inputs.emplace_back(input_index++, input_node);
std::vector<Node*> other_arguments;
DataTypeVector other_arguments_types;
TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
&other_arguments_types));
list_inputs.emplace_back(input_index++, other_arguments);
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
inputs.emplace_back(input_index++, cycle_length_node);
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
inputs.emplace_back(input_index++, block_length_node);
if (op_version_ >= 4) {
Node* buffer_output_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
inputs.emplace_back(input_index++, buffer_output_elements_node);
Node* prefetch_input_elements_node;
TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
&prefetch_input_elements_node));
inputs.emplace_back(input_index++, prefetch_input_elements_node);
}
Node* num_parallel_calls_node;
TF_RETURN_IF_ERROR(
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
inputs.emplace_back(input_index++, num_parallel_calls_node);
std::vector<std::pair<StringPiece, AttrValue>> attrs;
AttrValue f;
@ -207,18 +270,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
attrs.emplace_back(kSloppy, sloppy_attr);
}
if (op_version_ == 3) {
if (op_version_ >= 3) {
AttrValue deterministic_attr;
b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
attrs.emplace_back(kDeterministic, deterministic_attr);
}
TF_RETURN_IF_ERROR(b->AddDataset(this,
{{0, input_node},
{2, cycle_length_node},
{3, block_length_node},
{4, num_parallel_calls_node}},
{{1, other_arguments}}, attrs, output));
TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
return Status::OK();
}
@ -227,12 +285,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
public:
ParallelInterleaveIterator(const Params& params, bool deterministic)
: DatasetIterator<Dataset>(params),
per_iterator_prefetch_(
static_cast<int>(params.dataset->block_length_ *
kPerIteratorPrefetchFactor) +
1),
future_elements_prefetch_(static_cast<int>(
params.dataset->cycle_length_ * kCyclePrefetchFactor)),
mu_(std::make_shared<mutex>()),
num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
@ -257,7 +309,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// collection, `cycle_length_` threads for the current workers, and
// `future_elements_prefetch_` for the future workers.
int max_current_workers = dataset()->cycle_length_;
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
int future_workers =
dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
int num_threads = 1 + max_current_workers + future_workers;
if (ctx->stats_aggregator()) {
num_threads++;
@ -656,7 +709,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// `cycle_length_` future workers to guarantee that whenever
// `future_element_.size() < future_elements_prefetch_`, there will be a
// future worker available to create a new future element.
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
int future_workers =
dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
{
mutex_lock l(*mu_);
initial_current_workers = num_parallel_calls_->value;
@ -800,9 +854,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
current_workers_cond_var_.notify_one();
}
}
while (!cancelled_ &&
(future_elements_.size() >= future_elements_prefetch_ ||
wait_for_checkpoint_)) {
while (!cancelled_ && (future_elements_.size() >=
dataset()->prefetch_input_elements_ ||
wait_for_checkpoint_)) {
WaitWorkerThread(&future_workers_cond_var_, &l);
}
if (cancelled_) {
@ -860,7 +914,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
mutex_lock l(*mu_);
element->results.push_back(std::move(result));
NotifyElementUpdate(element);
if (element->results.size() == per_iterator_prefetch_) {
if (element->results.size() == dataset()->buffer_output_elements_) {
break;
}
}
@ -954,7 +1008,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
return true;
}
return element->iterator &&
element->results.size() < per_iterator_prefetch_;
element->results.size() < dataset()->buffer_output_elements_;
}
inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@ -1311,9 +1365,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// its elements are null.
int64 last_valid_current_element_ GUARDED_BY(mu_) = -1;
const int per_iterator_prefetch_;
const int future_elements_prefetch_;
// Identifies whether the current_elements_ vector has been initialized.
bool initial_elements_created_ GUARDED_BY(mu_) = false;
@ -1421,6 +1472,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_;
const int64 block_length_;
const int64 buffer_output_elements_;
const int64 prefetch_input_elements_;
const int64 num_parallel_calls_;
const DeterminismPolicy deterministic_;
const DataTypeVector output_types_;
@ -1431,7 +1484,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 2 : 3) {
: UnaryDatasetOpKernel(ctx),
op_version_(OpVersionFromOpName(ctx->def().op())) {
FunctionMetadata::Params params;
params.is_multi_device_function = true;
OP_REQUIRES_OK(ctx,
@ -1448,7 +1502,7 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
}
}
if (op_version_ == 3) {
if (op_version_ >= 3) {
std::string deterministic;
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
OP_REQUIRES_OK(
@ -1472,6 +1526,26 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0"));
int64 buffer_output_elements = model::kAutotune;
int64 prefetch_input_elements = model::kAutotune;
if (op_version_ >= 4) {
OP_REQUIRES(ctx,
buffer_output_elements == model::kAutotune ||
buffer_output_elements >= 0,
errors::InvalidArgument("`buffer_output_elements` must be ",
model::kAutotune, " or >= 0 but is ",
buffer_output_elements));
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
&prefetch_input_elements));
OP_REQUIRES(ctx,
prefetch_input_elements == model::kAutotune ||
prefetch_input_elements >= 0,
errors::InvalidArgument("`prefetch_input_elements` must be ",
model::kAutotune, " or >= 0 but is ",
prefetch_input_elements));
}
int64 num_parallel_calls = 0;
OP_REQUIRES_OK(
ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
@ -1492,18 +1566,22 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
metrics::RecordTFDataAutotune(kDatasetType);
}
*output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
block_length, num_parallel_calls, deterministic_,
output_types_, output_shapes_, op_version_);
*output = new Dataset(
ctx, input, std::move(captured_func), cycle_length, block_length,
buffer_output_elements, prefetch_input_elements, num_parallel_calls,
deterministic_, output_types_, output_shapes_, op_version_);
}
namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV3").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV2");
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV3");
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -29,6 +29,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
static constexpr const char* const kOtherArguments = "other_arguments";
static constexpr const char* const kCycleLength = "cycle_length";
static constexpr const char* const kBlockLength = "block_length";
static constexpr const char* const kBufferOutputElements =
"buffer_output_elements";
static constexpr const char* const kPrefetchInputElements =
"prefetch_input_elements";
static constexpr const char* const kNumParallelCalls = "num_parallel_calls";
static constexpr const char* const kFunc = "f";
static constexpr const char* const kTarguments = "Targuments";

View File

@ -18,14 +18,15 @@ namespace data {
namespace {
constexpr char kNodeName[] = "parallel_interleave_dataset";
constexpr int kOpVersion = 3;
constexpr int kOpVersion = 4;
class ParallelInterleaveDatasetParams : public DatasetParams {
public:
template <typename T>
ParallelInterleaveDatasetParams(
T input_dataset_params, std::vector<Tensor> other_arguments,
int64 cycle_length, int64 block_length, int64 num_parallel_calls,
int64 cycle_length, int64 block_length, int64 buffer_output_elements,
int64 prefetch_input_elements, int64 num_parallel_calls,
FunctionDefHelper::AttrValueWrapper func,
std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
const DataTypeVector& output_dtypes,
@ -36,6 +37,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
other_arguments_(std::move(other_arguments)),
cycle_length_(cycle_length),
block_length_(block_length),
buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements),
num_parallel_calls_(num_parallel_calls),
func_(std::move(func)),
func_lib_(std::move(func_lib)),
@ -56,6 +59,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
CreateTensor<int64>(TensorShape({}), {cycle_length_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {block_length_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {prefetch_input_elements_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {num_parallel_calls_}));
return input_tensors;
@ -69,6 +76,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
}
input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
input_names->emplace_back(
ParallelInterleaveDatasetOp::kBufferOutputElements);
input_names->emplace_back(
ParallelInterleaveDatasetOp::kPrefetchInputElements);
input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls);
return Status::OK();
}
@ -93,6 +104,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
std::vector<Tensor> other_arguments_;
int64 cycle_length_;
int64 block_length_;
int64 buffer_output_elements_;
int64 prefetch_input_elements_;
int64 num_parallel_calls_;
FunctionDefHelper::AttrValueWrapper func_;
std::vector<FunctionDef> func_lib_;
@ -111,8 +124,6 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
{"output_shapes", output_shapes}});
}
// test case 1: cycle_length = 1, block_length = 1, num_parallel_calls = 1,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -123,6 +134,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/1,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -136,8 +149,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*node_name=*/kNodeName);
}
// test case 2: cycle_length = 2, block_length = 1, num_parallel_calls = 2,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -148,6 +159,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*other_arguments=*/{},
/*cycle_length=*/2,
/*block_length=*/1,
/*buffer_output_elements=*/0,
/*prefetch_input_elements=*/0,
/*num_parallel_calls=*/2,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -161,8 +174,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*node_name=*/kNodeName);
}
// test case 3: cycle_length = 3, block_length = 1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -173,6 +184,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*other_arguments=*/{},
/*cycle_length=*/3,
/*block_length=*/1,
/*buffer_output_elements=*/0,
/*prefetch_input_elements=*/1,
/*num_parallel_calls=*/2,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -186,9 +199,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*node_name=*/kNodeName);
}
// test case 4: cycle_length = 5, block_length = 1, num_parallel_calls = 4,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -199,6 +209,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*other_arguments=*/{},
/*cycle_length=*/5,
/*block_length=*/1,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/0,
/*num_parallel_calls=*/4,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -212,8 +224,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*node_name=*/kNodeName);
}
// test case 5: cycle_length = 2, block_length = 2, num_parallel_calls = 1,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>(
@ -224,6 +234,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
/*other_arguments=*/{},
/*cycle_length=*/2,
/*block_length=*/2,
/*buffer_output_elements=*/2,
/*prefetch_input_elements=*/2,
/*num_parallel_calls=*/1,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -237,8 +249,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
/*node_name=*/kNodeName);
}
// test case 6: cycle_length = 2, block_length = 3, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>(
@ -249,6 +259,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
/*other_arguments=*/{},
/*cycle_length=*/2,
/*block_length=*/3,
/*buffer_output_elements=*/100,
/*prefetch_input_elements=*/100,
/*num_parallel_calls=*/2,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -262,8 +274,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
/*node_name=*/kNodeName);
}
// test case 7: cycle_length = 3, block_length = 2, num_parallel_calls = 2,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>(
@ -274,6 +284,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
/*other_arguments=*/{},
/*cycle_length=*/3,
/*block_length=*/2,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -287,8 +299,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
/*node_name=*/kNodeName);
}
// test case 8: cycle_length = 3, block_length = 3, num_parallel_calls = 3,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>(
@ -299,6 +309,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
/*other_arguments=*/{},
/*cycle_length=*/3,
/*block_length=*/3,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/3,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -312,8 +324,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
/*node_name=*/kNodeName);
}
// test case 9: cycle_length = 4, block_length = 4, num_parallel_calls = 4,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>(
@ -324,6 +334,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
/*other_arguments=*/{},
/*cycle_length=*/4,
/*block_length=*/4,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/4,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -337,8 +349,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
/*node_name=*/kNodeName);
}
// test case 10: cycle_length = 3, block_length = 3,
// num_parallel_calls = kAutotune, sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>(
@ -349,6 +359,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
/*other_arguments=*/{},
/*cycle_length=*/4,
/*block_length=*/4,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/model::kAutotune,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -372,6 +384,8 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
/*other_arguments=*/{},
/*cycle_length=*/11,
/*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -385,8 +399,6 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
/*node_name=*/kNodeName);
}
// test case 11: cycle_length = 0, block_length = 1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -398,6 +410,8 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
/*other_arguments=*/{},
/*cycle_length=*/0,
/*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -411,8 +425,6 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
/*node_name=*/kNodeName);
}
// test case 12: cycle_length = 1, block_length = -1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -424,6 +436,8 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/-1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -437,8 +451,6 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
/*node_name=*/kNodeName);
}
// test case 13: cycle_length = 1, block_length = 1, num_parallel_calls = -5,
// sloppy = true
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -450,6 +462,60 @@ ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/-5,
/*func=*/
MakeTensorSliceDatasetFunc(
DataTypeVector({DT_INT64}),
std::vector<PartialTensorShape>({PartialTensorShape({1})})),
/*func_lib=*/{test::function::MakeTensorSliceDataset()},
/*type_arguments=*/{},
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({1})},
/*deterministic=*/DeterminismPolicy::kNondeterministic,
/*node_name=*/kNodeName);
}
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8})},
/*node_name=*/"tensor_slice");
return ParallelInterleaveDatasetParams(
tensor_slice_dataset_params,
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*buffer_output_elements=*/-2,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/-5,
/*func=*/
MakeTensorSliceDatasetFunc(
DataTypeVector({DT_INT64}),
std::vector<PartialTensorShape>({PartialTensorShape({1})})),
/*func_lib=*/{test::function::MakeTensorSliceDataset()},
/*type_arguments=*/{},
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({1})},
/*deterministic=*/DeterminismPolicy::kNondeterministic,
/*node_name=*/kNodeName);
}
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8})},
/*node_name=*/"tensor_slice");
return ParallelInterleaveDatasetParams(
tensor_slice_dataset_params,
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*buffer_output_elements=*/-2,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/-5,
/*func=*/
MakeTensorSliceDatasetFunc(
@ -698,7 +764,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
std::vector<ParallelInterleaveDatasetParams> invalid_params = {
ParallelInterleaveDatasetParamsWithInvalidCycleLength(),
ParallelInterleaveDatasetParamsWithInvalidBlockLength(),
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls()};
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls(),
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements(),
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements(),
};
for (auto& dataset_params : invalid_params) {
EXPECT_EQ(Initialize(dataset_params).code(),
tensorflow::error::INVALID_ARGUMENT);

View File

@ -227,6 +227,24 @@ REGISTER_OP("ParallelInterleaveDatasetV3")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
// Like V3, but adds buffer_output_elements and prefetch_input_elements.
REGISTER_OP("ParallelInterleaveDatasetV4")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("cycle_length: int64")
.Input("block_length: int64")
.Input("buffer_output_elements: int64")
.Input("prefetch_input_elements: int64")
.Input("num_parallel_calls: int64")
.Output("handle: variant")
.Attr("f: func")
// "true", "false", or "default".
.Attr("deterministic: string = 'default'")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("FilterDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")

View File

@ -64,6 +64,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
parallel_interleave = "ParallelInterleaveV2"
if compat.forward_compatible(2020, 2, 20):
parallel_interleave = "ParallelInterleaveV3"
if compat.forward_compatible(2020, 3, 6):
parallel_interleave = "ParallelInterleaveV4"
dataset = dataset.apply(
testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
dataset = dataset.interleave(
@ -79,6 +81,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
parallel_interleave = "ParallelInterleaveV2"
if compat.forward_compatible(2020, 2, 20):
parallel_interleave = "ParallelInterleaveV3"
if compat.forward_compatible(2020, 3, 6):
parallel_interleave = "ParallelInterleaveV4"
dataset = dataset.apply(
testing.assert_next([
"ParallelMap", "Prefetch", parallel_interleave, "Prefetch",

View File

@ -1714,9 +1714,13 @@ name=None))
if num_parallel_calls is None:
return InterleaveDataset(self, map_func, cycle_length, block_length)
else:
return ParallelInterleaveDataset(self, map_func, cycle_length,
block_length, num_parallel_calls,
deterministic)
return ParallelInterleaveDataset(
self,
map_func,
cycle_length,
block_length,
num_parallel_calls,
deterministic=deterministic)
def filter(self, predicate):
"""Filters this dataset according to `predicate`.
@ -4042,6 +4046,8 @@ class ParallelInterleaveDataset(UnaryDataset):
cycle_length,
block_length,
num_parallel_calls,
buffer_output_elements=AUTOTUNE,
prefetch_input_elements=AUTOTUNE,
deterministic=None):
"""See `Dataset.interleave()` for details."""
self._input_dataset = input_dataset
@ -4056,6 +4062,15 @@ class ParallelInterleaveDataset(UnaryDataset):
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
self._buffer_output_elements = ops.convert_to_tensor(
buffer_output_elements,
dtype=dtypes.int64,
name="buffer_output_elements")
self._prefetch_input_elements = ops.convert_to_tensor(
prefetch_input_elements,
dtype=dtypes.int64,
name="prefetch_input_elements")
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
if deterministic is None:
@ -4065,7 +4080,21 @@ class ParallelInterleaveDataset(UnaryDataset):
else:
deterministic_string = "false"
if deterministic is not None or compat.forward_compatible(2020, 2, 20):
if (buffer_output_elements != AUTOTUNE or
prefetch_input_elements != AUTOTUNE or
compat.forward_compatible(2020, 3, 6)):
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs, # pylint: disable=protected-access
self._cycle_length,
self._block_length,
self._buffer_output_elements,
self._prefetch_input_elements,
self._num_parallel_calls,
f=self._map_func.function,
deterministic=deterministic_string,
**self._flat_structure)
elif deterministic is not None or compat.forward_compatible(2020, 2, 20):
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs, # pylint: disable=protected-access

View File

@ -2636,6 +2636,10 @@ tf_module {
name: "ParallelInterleaveDatasetV3"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method {
name: "ParallelInterleaveDatasetV4"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method {
name: "ParallelMapDataset"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "

View File

@ -2636,6 +2636,10 @@ tf_module {
name: "ParallelInterleaveDatasetV3"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method {
name: "ParallelInterleaveDatasetV4"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method {
name: "ParallelMapDataset"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "