Add API to control parallel interleave prefetching.

PiperOrigin-RevId: 294766661
Change-Id: I8061629522d19d408cd8b7a1981836a4ee958110
This commit is contained in:
Andrew Audibert 2020-02-12 15:06:13 -08:00 committed by TensorFlower Gardener
parent a7f1d52b03
commit 0c1ca5c674
12 changed files with 406 additions and 96 deletions

View File

@ -0,0 +1,100 @@
op {
graph_op_name: "ParallelInterleaveDatasetV4"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
Dataset that produces a stream of arguments for the function `f`.
END
}
in_arg {
name: "other_arguments"
description: <<END
Additional arguments to pass to `f` beyond those produced by `input_dataset`.
Evaluated once when the dataset is instantiated.
END
}
in_arg {
name: "cycle_length"
description: <<END
Number of datasets (each created by applying `f` to the elements of
`input_dataset`) among which the `ParallelInterleaveDatasetV4` will cycle in a
round-robin fashion.
END
}
in_arg {
name: "block_length"
description: <<END
Number of elements at a time to produce from each interleaved invocation of a
dataset returned by `f`.
END
}
in_arg {
name: "buffer_output_elements"
description: <<END
The number of elements each iterator being interleaved should buffer (similar
to the `.prefetch()` transformation for each interleaved iterator).
END
}
in_arg {
name: "prefetch_input_elements"
description: <<END
Determines the number of iterators to prefetch, allowing buffers to warm up and
data to be pre-fetched without blocking the main thread.
END
}
in_arg {
name: "num_parallel_calls"
description: <<END
Determines the number of threads that should be used for fetching data from
input datasets in parallel. The Python API `tf.data.experimental.AUTOTUNE`
constant can be used to indicate that the level of parallelism should be autotuned.
END
}
attr {
name: "f"
description: <<END
A function mapping elements of `input_dataset`, concatenated with
`other_arguments`, to a Dataset variant that contains elements matching
`output_types` and `output_shapes`.
END
}
attr {
name: "deterministic"
description: <<END
A string indicating the op-level determinism to use. Deterministic controls
whether the interleave is allowed to return elements out of order if the next
element to be returned isn't available, but a later element is. Options are
"true", "false", and "default". "default" indicates that determinism should be
decided by the `experimental_deterministic` parameter of `tf.data.Options`.
END
}
attr {
name: "Targuments"
description: <<END
Types of the elements of `other_arguments`.
END
}
attr {
name: "output_types"
}
attr {
name: "output_shapes"
}
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
description: <<END
The resulting dataset is similar to the `InterleaveDataset`, except that the
dataset will fetch records from the interleaved datasets in parallel.
The `tf.data` Python API creates instances of this op from
`Dataset.interleave()` when the `num_parallel_calls` parameter of that method
is set to any value other than `None`.
By default, the output of this dataset will be deterministic, which may result
in the dataset blocking if the next data item to be returned isn't available.
In order to avoid head-of-line blocking, one can either set the `deterministic`
attribute to "false", or leave it as "default" and set the
`experimental_deterministic` parameter of `tf.data.Options` to `False`.
This can improve performance at the expense of non-determinism.
END
}

View File

@ -89,13 +89,14 @@ constexpr std::array<const char*, 28> kPassThroughOps = {
}; };
// TODO(frankchn): Process functions within kFuncDatasetOps as well. // TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 6> kFuncDatasetOps = { constexpr std::array<const char*, 7> kFuncDatasetOps = {
"ExperimentalParallelInterleaveDataset", "ExperimentalParallelInterleaveDataset",
"FlatMapDataset", "FlatMapDataset",
"InterleaveDataset", "InterleaveDataset",
"ParallelInterleaveDataset", "ParallelInterleaveDataset",
"ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV2",
"ParallelInterleaveDatasetV3" "ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV4"
}; };
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = { constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {

View File

@ -33,12 +33,10 @@ namespace {
constexpr char kLegacyAutotune[] = "legacy_autotune"; constexpr char kLegacyAutotune[] = "legacy_autotune";
constexpr char kPrefetchDataset[] = "PrefetchDataset"; constexpr char kPrefetchDataset[] = "PrefetchDataset";
constexpr std::array<const char*, 5> kAsyncDatasetOps = { constexpr std::array<const char*, 6> kAsyncDatasetOps = {
"ExperimentalMapAndBatchDataset", "ExperimentalMapAndBatchDataset", "ParallelMapDataset",
"ParallelMapDataset", "ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV4", "MapAndBatchDataset",
"ParallelInterleaveDatasetV3",
"MapAndBatchDataset",
}; };
} // namespace } // namespace

View File

@ -32,8 +32,9 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
"ParseExampleDataset", "ParseExampleDataset",
}; };
constexpr std::array<const char*, 1> kDeterministicAttrOps = { constexpr std::array<const char*, 2> kDeterministicAttrOps = {
"ParallelInterleaveDatasetV3", "ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV4",
}; };
} // anonymous namespace } // anonymous namespace

View File

@ -59,6 +59,10 @@ namespace data {
ParallelInterleaveDatasetOp::kCycleLength; ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const /* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBlockLength; ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const /* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kNumParallelCalls; ParallelInterleaveDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc; /* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
@ -91,34 +95,64 @@ constexpr char kSizeSuffix[] = ".size";
constexpr char kInputsSuffix[] = ".inputs"; constexpr char kInputsSuffix[] = ".inputs";
constexpr char kIsReadySuffix[] = ".is_ready"; constexpr char kIsReadySuffix[] = ".is_ready";
// `kCyclePrefetchFactor * cycle_length` is the number of future cycle elements constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
// that will be prefetched ahead of time. The purpose of prefetching future constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
// cycle elements is to overlap expensive initialization (e.g. opening of a constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";
// remote file) with other computation.
constexpr double kCyclePrefetchFactor = 2.0L;
// `kPerIteratorPrefetchFactor * block_length + 1` is the number of per-iterator // `kCyclePrefetchFactor * cycle_length` is the default number of future cycle
// results that will be prefetched ahead of time. The `+ 1` is to match the // elements that will be prefetched ahead of time. The purpose of prefetching
// behavior of the original autotune implementation. // future cycle elements is to overlap expensive initialization (e.g. opening of
constexpr double kPerIteratorPrefetchFactor = 2.0L; // a remote file) with other computation.
constexpr double kDefaultCyclePrefetchFactor = 2.0L;
// `kDefaultPerIteratorPrefetchFactor * block_length + 1` is the default number
// of per-iterator results that will be prefetched ahead of time. The `+ 1` is
// to match the behavior of the original implementation.
constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;
// Period between reporting dataset statistics. // Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000; constexpr int kStatsReportingPeriodMillis = 1000;
namespace {
// Resolves the number of output elements each interleaved iterator buffers.
//
// An explicitly configured value is returned unchanged. When the configured
// value equals `model::kAutotune`, the default of
// `kDefaultPerIteratorPrefetchFactor * block_length + 1` is used instead.
int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
                                  int64 block_length) {
  const bool use_default =
      configured_buffer_output_elements == model::kAutotune;
  if (use_default) {
    // The product is computed in floating point and then truncated to int64.
    return static_cast<int64>(kDefaultPerIteratorPrefetchFactor * block_length +
                              1);
  }
  return configured_buffer_output_elements;
}
// Resolves the number of future cycle elements to prefetch.
//
// An explicitly configured value is returned unchanged. When the configured
// value equals `model::kAutotune`, the default of
// `kDefaultCyclePrefetchFactor * cycle_length` is used instead.
int64 ComputePrefetchInputElements(int64 configured_prefetch_input_elements,
                                   int64 cycle_length) {
  const bool use_default =
      configured_prefetch_input_elements == model::kAutotune;
  if (use_default) {
    // The product is computed in floating point and then truncated to int64.
    return static_cast<int64>(kDefaultCyclePrefetchFactor * cycle_length);
  }
  return configured_prefetch_input_elements;
}
// Maps a registered op name to its numeric op version (2, 3 or 4).
//
// Any name that is neither the V2 nor the V3 op is reported as version 4; a
// debug-only check asserts that such a name really is the V4 op.
int64 OpVersionFromOpName(absl::string_view op_name) {
  if (op_name == kParallelInterleaveDatasetV2) {
    return 2;
  }
  if (op_name == kParallelInterleaveDatasetV3) {
    return 3;
  }
  DCHECK_EQ(op_name, kParallelInterleaveDatasetV4);
  return 4;
}
} // namespace
// The motivation for creating an alternative implementation of parallel // The motivation for creating an alternative implementation of parallel
// interleave is to decouple the degree of parallelism from the cycle length. // interleave is to decouple the degree of parallelism from the cycle length.
// This makes it possible to change the degree of parallelism (e.g. through // This makes it possible to change the degree of parallelism (e.g. through
// auto-tuning) without changing the cycle length (which would change the order // auto-tuning) without changing the cycle length (which would change the order
// in which elements are produced). // in which elements are produced).
//
// Furthermore, this class favors modularity over extended functionality. In
// particular, it refrains from implementing configurable buffering of output
// elements and prefetching of input iterators.
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
public: public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, int64 num_parallel_calls, int64 block_length, int64 buffer_output_elements,
int64 prefetch_input_elements, int64 num_parallel_calls,
DeterminismPolicy deterministic, const DataTypeVector& output_types, DeterminismPolicy deterministic, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes, int op_version) const std::vector<PartialTensorShape>& output_shapes, int op_version)
: DatasetBase(DatasetContext(ctx)), : DatasetBase(DatasetContext(ctx)),
@ -126,6 +160,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
captured_func_(std::move(captured_func)), captured_func_(std::move(captured_func)),
cycle_length_(cycle_length), cycle_length_(cycle_length),
block_length_(block_length), block_length_(block_length),
buffer_output_elements_(
ComputeBufferOutputElements(buffer_output_elements, block_length)),
prefetch_input_elements_(ComputePrefetchInputElements(
prefetch_input_elements, cycle_length)),
num_parallel_calls_(num_parallel_calls), num_parallel_calls_(num_parallel_calls),
deterministic_(deterministic), deterministic_(deterministic),
output_types_(output_types), output_types_(output_types),
@ -179,19 +217,44 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
Node** output) const override { Node** output) const override {
std::vector<std::pair<size_t, Node*>> inputs;
std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
int input_index = 0;
Node* input_node; Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node; inputs.emplace_back(input_index++, input_node);
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
Node* num_parallel_calls_node;
TF_RETURN_IF_ERROR(
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
std::vector<Node*> other_arguments; std::vector<Node*> other_arguments;
DataTypeVector other_arguments_types; DataTypeVector other_arguments_types;
TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
&other_arguments_types)); &other_arguments_types));
list_inputs.emplace_back(input_index++, other_arguments);
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
inputs.emplace_back(input_index++, cycle_length_node);
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
inputs.emplace_back(input_index++, block_length_node);
if (op_version_ >= 4) {
Node* buffer_output_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
inputs.emplace_back(input_index++, buffer_output_elements_node);
Node* prefetch_input_elements_node;
TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
&prefetch_input_elements_node));
inputs.emplace_back(input_index++, prefetch_input_elements_node);
}
Node* num_parallel_calls_node;
TF_RETURN_IF_ERROR(
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
inputs.emplace_back(input_index++, num_parallel_calls_node);
std::vector<std::pair<StringPiece, AttrValue>> attrs; std::vector<std::pair<StringPiece, AttrValue>> attrs;
AttrValue f; AttrValue f;
@ -207,18 +270,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr); b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
attrs.emplace_back(kSloppy, sloppy_attr); attrs.emplace_back(kSloppy, sloppy_attr);
} }
if (op_version_ == 3) { if (op_version_ >= 3) {
AttrValue deterministic_attr; AttrValue deterministic_attr;
b->BuildAttrValue(deterministic_.String(), &deterministic_attr); b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
attrs.emplace_back(kDeterministic, deterministic_attr); attrs.emplace_back(kDeterministic, deterministic_attr);
} }
TF_RETURN_IF_ERROR(b->AddDataset(this, TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
{{0, input_node},
{2, cycle_length_node},
{3, block_length_node},
{4, num_parallel_calls_node}},
{{1, other_arguments}}, attrs, output));
return Status::OK(); return Status::OK();
} }
@ -227,12 +285,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
public: public:
ParallelInterleaveIterator(const Params& params, bool deterministic) ParallelInterleaveIterator(const Params& params, bool deterministic)
: DatasetIterator<Dataset>(params), : DatasetIterator<Dataset>(params),
per_iterator_prefetch_(
static_cast<int>(params.dataset->block_length_ *
kPerIteratorPrefetchFactor) +
1),
future_elements_prefetch_(static_cast<int>(
params.dataset->cycle_length_ * kCyclePrefetchFactor)),
mu_(std::make_shared<mutex>()), mu_(std::make_shared<mutex>()),
num_parallel_calls_cond_var_(std::make_shared<condition_variable>()), num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>( num_parallel_calls_(std::make_shared<model::SharedState>(
@ -257,7 +309,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// collection, `cycle_length_` threads for the current workers, and // collection, `cycle_length_` threads for the current workers, and
// `future_elements_prefetch_` for the future workers. // `future_elements_prefetch_` for the future workers.
int max_current_workers = dataset()->cycle_length_; int max_current_workers = dataset()->cycle_length_;
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_; int future_workers =
dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
int num_threads = 1 + max_current_workers + future_workers; int num_threads = 1 + max_current_workers + future_workers;
if (ctx->stats_aggregator()) { if (ctx->stats_aggregator()) {
num_threads++; num_threads++;
@ -656,7 +709,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// `cycle_length_` future workers to guarantee that whenever // `cycle_length_` future workers to guarantee that whenever
// `future_element_.size() < future_elements_prefetch_`, there will be a // `future_element_.size() < future_elements_prefetch_`, there will be a
// future worker available to create a new future element. // future worker available to create a new future element.
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_; int future_workers =
dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
{ {
mutex_lock l(*mu_); mutex_lock l(*mu_);
initial_current_workers = num_parallel_calls_->value; initial_current_workers = num_parallel_calls_->value;
@ -800,8 +854,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
current_workers_cond_var_.notify_one(); current_workers_cond_var_.notify_one();
} }
} }
while (!cancelled_ && while (!cancelled_ && (future_elements_.size() >=
(future_elements_.size() >= future_elements_prefetch_ || dataset()->prefetch_input_elements_ ||
wait_for_checkpoint_)) { wait_for_checkpoint_)) {
WaitWorkerThread(&future_workers_cond_var_, &l); WaitWorkerThread(&future_workers_cond_var_, &l);
} }
@ -860,7 +914,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
mutex_lock l(*mu_); mutex_lock l(*mu_);
element->results.push_back(std::move(result)); element->results.push_back(std::move(result));
NotifyElementUpdate(element); NotifyElementUpdate(element);
if (element->results.size() == per_iterator_prefetch_) { if (element->results.size() == dataset()->buffer_output_elements_) {
break; break;
} }
} }
@ -954,7 +1008,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
return true; return true;
} }
return element->iterator && return element->iterator &&
element->results.size() < per_iterator_prefetch_; element->results.size() < dataset()->buffer_output_elements_;
} }
inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) { inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@ -1311,9 +1365,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// its elements are null. // its elements are null.
int64 last_valid_current_element_ GUARDED_BY(mu_) = -1; int64 last_valid_current_element_ GUARDED_BY(mu_) = -1;
const int per_iterator_prefetch_;
const int future_elements_prefetch_;
// Identifies whether the current_elements_ vector has been initialized. // Identifies whether the current_elements_ vector has been initialized.
bool initial_elements_created_ GUARDED_BY(mu_) = false; bool initial_elements_created_ GUARDED_BY(mu_) = false;
@ -1421,6 +1472,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
const std::unique_ptr<CapturedFunction> captured_func_; const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_; const int64 cycle_length_;
const int64 block_length_; const int64 block_length_;
const int64 buffer_output_elements_;
const int64 prefetch_input_elements_;
const int64 num_parallel_calls_; const int64 num_parallel_calls_;
const DeterminismPolicy deterministic_; const DeterminismPolicy deterministic_;
const DataTypeVector output_types_; const DataTypeVector output_types_;
@ -1431,7 +1484,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp( ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
OpKernelConstruction* ctx) OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 2 : 3) { : UnaryDatasetOpKernel(ctx),
op_version_(OpVersionFromOpName(ctx->def().op())) {
FunctionMetadata::Params params; FunctionMetadata::Params params;
params.is_multi_device_function = true; params.is_multi_device_function = true;
OP_REQUIRES_OK(ctx, OP_REQUIRES_OK(ctx,
@ -1448,7 +1502,7 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault); deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
} }
} }
if (op_version_ == 3) { if (op_version_ >= 3) {
std::string deterministic; std::string deterministic;
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
OP_REQUIRES_OK( OP_REQUIRES_OK(
@ -1472,6 +1526,26 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
OP_REQUIRES(ctx, block_length > 0, OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0")); errors::InvalidArgument("`block_length` must be > 0"));
// Both values start at the autotune sentinel. Op versions below 4 have no
// inputs for them, so the sentinel is what gets passed on to the dataset.
int64 buffer_output_elements = model::kAutotune;
int64 prefetch_input_elements = model::kAutotune;
if (op_version_ >= 4) {
  // Read the caller-supplied value before validating it. Without this parse
  // the check below would only ever see the sentinel, and the
  // `buffer_output_elements` input would be silently ignored.
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
                                          &buffer_output_elements));
  OP_REQUIRES(ctx,
              buffer_output_elements == model::kAutotune ||
                  buffer_output_elements >= 0,
              errors::InvalidArgument("`buffer_output_elements` must be ",
                                      model::kAutotune, " or >= 0 but is ",
                                      buffer_output_elements));
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
                                          &prefetch_input_elements));
  OP_REQUIRES(ctx,
              prefetch_input_elements == model::kAutotune ||
                  prefetch_input_elements >= 0,
              errors::InvalidArgument("`prefetch_input_elements` must be ",
                                      model::kAutotune, " or >= 0 but is ",
                                      prefetch_input_elements));
}
int64 num_parallel_calls = 0; int64 num_parallel_calls = 0;
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls)); ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
@ -1492,18 +1566,22 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
metrics::RecordTFDataAutotune(kDatasetType); metrics::RecordTFDataAutotune(kDatasetType);
} }
*output = new Dataset(ctx, input, std::move(captured_func), cycle_length, *output = new Dataset(
block_length, num_parallel_calls, deterministic_, ctx, input, std::move(captured_func), cycle_length, block_length,
output_types_, output_shapes_, op_version_); buffer_output_elements, prefetch_input_elements, num_parallel_calls,
deterministic_, output_types_, output_shapes_, op_version_);
} }
namespace { namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
ParallelInterleaveDatasetOp); ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV3").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
ParallelInterleaveDatasetOp); ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV2"); REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV3"); ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
} // namespace } // namespace
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -29,6 +29,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
static constexpr const char* const kOtherArguments = "other_arguments"; static constexpr const char* const kOtherArguments = "other_arguments";
static constexpr const char* const kCycleLength = "cycle_length"; static constexpr const char* const kCycleLength = "cycle_length";
static constexpr const char* const kBlockLength = "block_length"; static constexpr const char* const kBlockLength = "block_length";
static constexpr const char* const kBufferOutputElements =
"buffer_output_elements";
static constexpr const char* const kPrefetchInputElements =
"prefetch_input_elements";
static constexpr const char* const kNumParallelCalls = "num_parallel_calls"; static constexpr const char* const kNumParallelCalls = "num_parallel_calls";
static constexpr const char* const kFunc = "f"; static constexpr const char* const kFunc = "f";
static constexpr const char* const kTarguments = "Targuments"; static constexpr const char* const kTarguments = "Targuments";

View File

@ -18,14 +18,15 @@ namespace data {
namespace { namespace {
constexpr char kNodeName[] = "parallel_interleave_dataset"; constexpr char kNodeName[] = "parallel_interleave_dataset";
constexpr int kOpVersion = 3; constexpr int kOpVersion = 4;
class ParallelInterleaveDatasetParams : public DatasetParams { class ParallelInterleaveDatasetParams : public DatasetParams {
public: public:
template <typename T> template <typename T>
ParallelInterleaveDatasetParams( ParallelInterleaveDatasetParams(
T input_dataset_params, std::vector<Tensor> other_arguments, T input_dataset_params, std::vector<Tensor> other_arguments,
int64 cycle_length, int64 block_length, int64 num_parallel_calls, int64 cycle_length, int64 block_length, int64 buffer_output_elements,
int64 prefetch_input_elements, int64 num_parallel_calls,
FunctionDefHelper::AttrValueWrapper func, FunctionDefHelper::AttrValueWrapper func,
std::vector<FunctionDef> func_lib, DataTypeVector type_arguments, std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
const DataTypeVector& output_dtypes, const DataTypeVector& output_dtypes,
@ -36,6 +37,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
other_arguments_(std::move(other_arguments)), other_arguments_(std::move(other_arguments)),
cycle_length_(cycle_length), cycle_length_(cycle_length),
block_length_(block_length), block_length_(block_length),
buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements),
num_parallel_calls_(num_parallel_calls), num_parallel_calls_(num_parallel_calls),
func_(std::move(func)), func_(std::move(func)),
func_lib_(std::move(func_lib)), func_lib_(std::move(func_lib)),
@ -56,6 +59,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
CreateTensor<int64>(TensorShape({}), {cycle_length_})); CreateTensor<int64>(TensorShape({}), {cycle_length_}));
input_tensors.emplace_back( input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {block_length_})); CreateTensor<int64>(TensorShape({}), {block_length_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {prefetch_input_elements_}));
input_tensors.emplace_back( input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {num_parallel_calls_})); CreateTensor<int64>(TensorShape({}), {num_parallel_calls_}));
return input_tensors; return input_tensors;
@ -69,6 +76,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
} }
input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength); input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength); input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
input_names->emplace_back(
ParallelInterleaveDatasetOp::kBufferOutputElements);
input_names->emplace_back(
ParallelInterleaveDatasetOp::kPrefetchInputElements);
input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls); input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls);
return Status::OK(); return Status::OK();
} }
@ -93,6 +104,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
std::vector<Tensor> other_arguments_; std::vector<Tensor> other_arguments_;
int64 cycle_length_; int64 cycle_length_;
int64 block_length_; int64 block_length_;
int64 buffer_output_elements_;
int64 prefetch_input_elements_;
int64 num_parallel_calls_; int64 num_parallel_calls_;
FunctionDefHelper::AttrValueWrapper func_; FunctionDefHelper::AttrValueWrapper func_;
std::vector<FunctionDef> func_lib_; std::vector<FunctionDef> func_lib_;
@ -111,8 +124,6 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
{"output_shapes", output_shapes}}); {"output_shapes", output_shapes}});
} }
// test case 1: cycle_length = 1, block_length = 1, num_parallel_calls = 1,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1}, /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -123,6 +134,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/1, /*cycle_length=*/1,
/*block_length=*/1, /*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/1, /*num_parallel_calls=*/1,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -136,8 +149,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 2: cycle_length = 2, block_length = 1, num_parallel_calls = 2,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1}, /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -148,6 +159,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/2, /*cycle_length=*/2,
/*block_length=*/1, /*block_length=*/1,
/*buffer_output_elements=*/0,
/*prefetch_input_elements=*/0,
/*num_parallel_calls=*/2, /*num_parallel_calls=*/2,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -161,8 +174,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 3: cycle_length = 3, block_length = 1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1}, /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -173,6 +184,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/3, /*cycle_length=*/3,
/*block_length=*/1, /*block_length=*/1,
/*buffer_output_elements=*/0,
/*prefetch_input_elements=*/1,
/*num_parallel_calls=*/2, /*num_parallel_calls=*/2,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -186,9 +199,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 4: cycle_length = 5, block_length = 1, num_parallel_calls = 4,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1}, /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
@ -199,6 +209,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/5, /*cycle_length=*/5,
/*block_length=*/1, /*block_length=*/1,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/0,
/*num_parallel_calls=*/4, /*num_parallel_calls=*/4,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -212,8 +224,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 5: cycle_length = 2, block_length = 2, num_parallel_calls = 1,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>( /*components=*/{CreateTensor<tstring>(
@ -224,6 +234,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/2, /*cycle_length=*/2,
/*block_length=*/2, /*block_length=*/2,
/*buffer_output_elements=*/2,
/*prefetch_input_elements=*/2,
/*num_parallel_calls=*/1, /*num_parallel_calls=*/1,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -237,8 +249,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 6: cycle_length = 2, block_length = 3, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>( /*components=*/{CreateTensor<tstring>(
@ -249,6 +259,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/2, /*cycle_length=*/2,
/*block_length=*/3, /*block_length=*/3,
/*buffer_output_elements=*/100,
/*prefetch_input_elements=*/100,
/*num_parallel_calls=*/2, /*num_parallel_calls=*/2,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -262,8 +274,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 7: cycle_length = 3, block_length = 2, num_parallel_calls = 2,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>( /*components=*/{CreateTensor<tstring>(
@ -274,6 +284,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/3, /*cycle_length=*/3,
/*block_length=*/2, /*block_length=*/2,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2, /*num_parallel_calls=*/2,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -287,8 +299,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 8: cycle_length = 3, block_length = 3, num_parallel_calls = 3,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>( /*components=*/{CreateTensor<tstring>(
@ -299,6 +309,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/3, /*cycle_length=*/3,
/*block_length=*/3, /*block_length=*/3,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/3, /*num_parallel_calls=*/3,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -312,8 +324,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 9: cycle_length = 4, block_length = 4, num_parallel_calls = 4,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>( /*components=*/{CreateTensor<tstring>(
@ -324,6 +334,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/4, /*cycle_length=*/4,
/*block_length=*/4, /*block_length=*/4,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/4, /*num_parallel_calls=*/4,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -337,8 +349,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 10: cycle_length = 3, block_length = 3,
// num_parallel_calls = kAutotune, sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<tstring>( /*components=*/{CreateTensor<tstring>(
@ -349,6 +359,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/4, /*cycle_length=*/4,
/*block_length=*/4, /*block_length=*/4,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/model::kAutotune, /*num_parallel_calls=*/model::kAutotune,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -372,6 +384,8 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/11, /*cycle_length=*/11,
/*block_length=*/1, /*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2, /*num_parallel_calls=*/2,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -385,8 +399,6 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 11: cycle_length = 0, block_length = 1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidCycleLength() { ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -398,6 +410,8 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/0, /*cycle_length=*/0,
/*block_length=*/1, /*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2, /*num_parallel_calls=*/2,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -411,8 +425,6 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 12: cycle_length = 1, block_length = -1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidBlockLength() { ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -424,6 +436,8 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/1, /*cycle_length=*/1,
/*block_length=*/-1, /*block_length=*/-1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/2, /*num_parallel_calls=*/2,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -437,8 +451,6 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// test case 13: cycle_length = 1, block_length = 1, num_parallel_calls = -5,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() { ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -450,6 +462,60 @@ ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/1, /*cycle_length=*/1,
/*block_length=*/1, /*block_length=*/1,
/*buffer_output_elements=*/model::kAutotune,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/-5,
/*func=*/
MakeTensorSliceDatasetFunc(
DataTypeVector({DT_INT64}),
std::vector<PartialTensorShape>({PartialTensorShape({1})})),
/*func_lib=*/{test::function::MakeTensorSliceDataset()},
/*type_arguments=*/{},
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({1})},
/*deterministic=*/DeterminismPolicy::kNondeterministic,
/*node_name=*/kNodeName);
}
// Builds parameters for the InvalidArguments test in which
// `buffer_output_elements` (-2) is the only argument intended to be rejected.
// All other arguments are kept at values that the valid test cases in this
// file also use, so an INVALID_ARGUMENT status can be attributed to
// `buffer_output_elements` rather than to some other input.
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
                                          {0, 1, 2, 3, 4, 5, 6, 7, 8})},
      /*node_name=*/"tensor_slice");
  return ParallelInterleaveDatasetParams(
      tensor_slice_dataset_params,
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
      // The value under test.
      /*buffer_output_elements=*/-2,
      /*prefetch_input_elements=*/model::kAutotune,
      // Previously -5, which duplicated the invalid-num_parallel_calls case and
      // could mask whether `buffer_output_elements` is validated at all. Use a
      // positive value that does not exceed `cycle_length`.
      /*num_parallel_calls=*/1,
      /*func=*/
      MakeTensorSliceDatasetFunc(
          DataTypeVector({DT_INT64}),
          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8})},
/*node_name=*/"tensor_slice");
return ParallelInterleaveDatasetParams(
tensor_slice_dataset_params,
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*buffer_output_elements=*/-2,
/*prefetch_input_elements=*/model::kAutotune,
/*num_parallel_calls=*/-5, /*num_parallel_calls=*/-5,
/*func=*/ /*func=*/
MakeTensorSliceDatasetFunc( MakeTensorSliceDatasetFunc(
@ -698,7 +764,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
std::vector<ParallelInterleaveDatasetParams> invalid_params = { std::vector<ParallelInterleaveDatasetParams> invalid_params = {
ParallelInterleaveDatasetParamsWithInvalidCycleLength(), ParallelInterleaveDatasetParamsWithInvalidCycleLength(),
ParallelInterleaveDatasetParamsWithInvalidBlockLength(), ParallelInterleaveDatasetParamsWithInvalidBlockLength(),
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls()}; ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls(),
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements(),
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements(),
};
for (auto& dataset_params : invalid_params) { for (auto& dataset_params : invalid_params) {
EXPECT_EQ(Initialize(dataset_params).code(), EXPECT_EQ(Initialize(dataset_params).code(),
tensorflow::error::INVALID_ARGUMENT); tensorflow::error::INVALID_ARGUMENT);

View File

@ -227,6 +227,24 @@ REGISTER_OP("ParallelInterleaveDatasetV3")
.Attr("output_shapes: list(shape) >= 1") .Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape); .SetShapeFn(shape_inference::ScalarShape);
// Like V3, but adds buffer_output_elements and prefetch_input_elements.
//
// Both new inputs are int64 and sit between `block_length` and
// `num_parallel_calls`, so the positional input order differs from V3:
//   * buffer_output_elements: the number of elements each iterator being
//     interleaved should buffer.
//   * prefetch_input_elements: determines the number of iterators to prefetch.
// The attrs (`f`, `deterministic`, `Targuments`, `output_types`,
// `output_shapes`) carry over from V3 unchanged.
REGISTER_OP("ParallelInterleaveDatasetV4")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("cycle_length: int64")
.Input("block_length: int64")
.Input("buffer_output_elements: int64")
.Input("prefetch_input_elements: int64")
.Input("num_parallel_calls: int64")
.Output("handle: variant")
.Attr("f: func")
// "true", "false", or "default".
.Attr("deterministic: string = 'default'")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("FilterDataset") REGISTER_OP("FilterDataset")
.Input("input_dataset: variant") .Input("input_dataset: variant")
.Input("other_arguments: Targuments") .Input("other_arguments: Targuments")

View File

@ -64,6 +64,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
parallel_interleave = "ParallelInterleaveV2" parallel_interleave = "ParallelInterleaveV2"
if compat.forward_compatible(2020, 2, 20): if compat.forward_compatible(2020, 2, 20):
parallel_interleave = "ParallelInterleaveV3" parallel_interleave = "ParallelInterleaveV3"
if compat.forward_compatible(2020, 3, 6):
parallel_interleave = "ParallelInterleaveV4"
dataset = dataset.apply( dataset = dataset.apply(
testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"])) testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
dataset = dataset.interleave( dataset = dataset.interleave(
@ -79,6 +81,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
parallel_interleave = "ParallelInterleaveV2" parallel_interleave = "ParallelInterleaveV2"
if compat.forward_compatible(2020, 2, 20): if compat.forward_compatible(2020, 2, 20):
parallel_interleave = "ParallelInterleaveV3" parallel_interleave = "ParallelInterleaveV3"
if compat.forward_compatible(2020, 3, 6):
parallel_interleave = "ParallelInterleaveV4"
dataset = dataset.apply( dataset = dataset.apply(
testing.assert_next([ testing.assert_next([
"ParallelMap", "Prefetch", parallel_interleave, "Prefetch", "ParallelMap", "Prefetch", parallel_interleave, "Prefetch",

View File

@ -1714,9 +1714,13 @@ name=None))
if num_parallel_calls is None: if num_parallel_calls is None:
return InterleaveDataset(self, map_func, cycle_length, block_length) return InterleaveDataset(self, map_func, cycle_length, block_length)
else: else:
return ParallelInterleaveDataset(self, map_func, cycle_length, return ParallelInterleaveDataset(
block_length, num_parallel_calls, self,
deterministic) map_func,
cycle_length,
block_length,
num_parallel_calls,
deterministic=deterministic)
def filter(self, predicate): def filter(self, predicate):
"""Filters this dataset according to `predicate`. """Filters this dataset according to `predicate`.
@ -4042,6 +4046,8 @@ class ParallelInterleaveDataset(UnaryDataset):
cycle_length, cycle_length,
block_length, block_length,
num_parallel_calls, num_parallel_calls,
buffer_output_elements=AUTOTUNE,
prefetch_input_elements=AUTOTUNE,
deterministic=None): deterministic=None):
"""See `Dataset.interleave()` for details.""" """See `Dataset.interleave()` for details."""
self._input_dataset = input_dataset self._input_dataset = input_dataset
@ -4056,6 +4062,15 @@ class ParallelInterleaveDataset(UnaryDataset):
cycle_length, dtype=dtypes.int64, name="cycle_length") cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor( self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length") block_length, dtype=dtypes.int64, name="block_length")
self._buffer_output_elements = ops.convert_to_tensor(
buffer_output_elements,
dtype=dtypes.int64,
name="buffer_output_elements")
self._prefetch_input_elements = ops.convert_to_tensor(
prefetch_input_elements,
dtype=dtypes.int64,
name="prefetch_input_elements")
self._num_parallel_calls = ops.convert_to_tensor( self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
if deterministic is None: if deterministic is None:
@ -4065,7 +4080,21 @@ class ParallelInterleaveDataset(UnaryDataset):
else: else:
deterministic_string = "false" deterministic_string = "false"
if deterministic is not None or compat.forward_compatible(2020, 2, 20): if (buffer_output_elements != AUTOTUNE or
prefetch_input_elements != AUTOTUNE or
compat.forward_compatible(2020, 3, 6)):
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs, # pylint: disable=protected-access
self._cycle_length,
self._block_length,
self._buffer_output_elements,
self._prefetch_input_elements,
self._num_parallel_calls,
f=self._map_func.function,
deterministic=deterministic_string,
**self._flat_structure)
elif deterministic is not None or compat.forward_compatible(2020, 2, 20):
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3( variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
input_dataset._variant_tensor, # pylint: disable=protected-access input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs, # pylint: disable=protected-access self._map_func.function.captured_inputs, # pylint: disable=protected-access

View File

@ -2636,6 +2636,10 @@ tf_module {
name: "ParallelInterleaveDatasetV3" name: "ParallelInterleaveDatasetV3"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], " argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
} }
member_method {
name: "ParallelInterleaveDatasetV4"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method { member_method {
name: "ParallelMapDataset" name: "ParallelMapDataset"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], " argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "

View File

@ -2636,6 +2636,10 @@ tf_module {
name: "ParallelInterleaveDatasetV3" name: "ParallelInterleaveDatasetV3"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], " argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
} }
member_method {
name: "ParallelInterleaveDatasetV4"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method { member_method {
name: "ParallelMapDataset" name: "ParallelMapDataset"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], " argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "