Add API to control parallel interleave prefetching.
PiperOrigin-RevId: 294766661
Change-Id: I8061629522d19d408cd8b7a1981836a4ee958110
parent a7f1d52b03
commit 0c1ca5c674
@@ -0,0 +1,100 @@
op {
  graph_op_name: "ParallelInterleaveDatasetV4"
  visibility: HIDDEN
  in_arg {
    name: "input_dataset"
    description: <<END
Dataset that produces a stream of arguments for the function `f`.
END
  }
  in_arg {
    name: "other_arguments"
    description: <<END
Additional arguments to pass to `f` beyond those produced by `input_dataset`.
Evaluated once when the dataset is instantiated.
END
  }
  in_arg {
    name: "cycle_length"
    description: <<END
Number of datasets (each created by applying `f` to the elements of
`input_dataset`) among which the `ParallelInterleaveDatasetV4` will cycle in a
round-robin fashion.
END
  }
  in_arg {
    name: "block_length"
    description: <<END
Number of elements at a time to produce from each interleaved invocation of a
dataset returned by `f`.
END
  }
  in_arg {
    name: "buffer_output_elements"
    description: <<END
The number of elements each iterator being interleaved should buffer (similar
to the `.prefetch()` transformation for each interleaved iterator).
END
  }
  in_arg {
    name: "prefetch_input_elements"
    description: <<END
Determines the number of iterators to prefetch, allowing buffers to warm up and
data to be pre-fetched without blocking the main thread.
END
  }
  in_arg {
    name: "num_parallel_calls"
    description: <<END
Determines the number of threads that should be used for fetching data from
input datasets in parallel. The Python API `tf.data.experimental.AUTOTUNE`
constant can be used to indicate that the level of parallelism should be
autotuned.
END
  }
  attr {
    name: "f"
    description: <<END
A function mapping elements of `input_dataset`, concatenated with
`other_arguments`, to a Dataset variant that contains elements matching
`output_types` and `output_shapes`.
END
  }
  attr {
    name: "deterministic"
    description: <<END
A string indicating the op-level determinism to use. Deterministic controls
whether the interleave is allowed to return elements out of order if the next
element to be returned isn't available, but a later element is. Options are
"true", "false", and "default". "default" indicates that determinism should be
decided by the `experimental_deterministic` parameter of `tf.data.Options`.
END
  }
  attr {
    name: "Targuments"
    description: <<END
Types of the elements of `other_arguments`.
END
  }
  attr {
    name: "output_types"
  }
  attr {
    name: "output_shapes"
  }
  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
  description: <<END
The resulting dataset is similar to the `InterleaveDataset`, except that the
dataset will fetch records from the interleaved datasets in parallel.

The `tf.data` Python API creates instances of this op from
`Dataset.interleave()` when the `num_parallel_calls` parameter of that method
is set to any value other than `None`.

By default, the output of this dataset will be deterministic, which may result
in the dataset blocking if the next data item to be returned isn't available.
In order to avoid head-of-line blocking, one can either set the `deterministic`
attribute to "false", or leave it as "default" and set the
`experimental_deterministic` parameter of `tf.data.Options` to `False`.
This can improve performance at the expense of non-determinism.
END
}
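The description above corresponds to what the public `tf.data` API does under the hood. As a rough usage sketch (assuming the TF 2.x Python API of this period; the file names are placeholders), the op is reached through `Dataset.interleave()` with a non-`None` `num_parallel_calls`, and head-of-line blocking can be traded away through `tf.data.Options`:

    import tensorflow as tf

    # Hypothetical input files; TFRecordDataset is used only for illustration.
    filenames = tf.data.Dataset.from_tensor_slices(["a.tfrecord", "b.tfrecord"])

    # A non-None num_parallel_calls selects the parallel interleave op;
    # AUTOTUNE asks tf.data to pick the level of parallelism.
    dataset = filenames.interleave(
        tf.data.TFRecordDataset,
        cycle_length=4,
        block_length=16,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Leaving the op-level `deterministic` attr at "default" and disabling
    # experimental_deterministic allows out-of-order elements instead of
    # blocking on the next element in cycle order.
    options = tf.data.Options()
    options.experimental_deterministic = False
    dataset = dataset.with_options(options)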
@@ -89,13 +89,14 @@ constexpr std::array<const char*, 28> kPassThroughOps = {
};

// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 6> kFuncDatasetOps = {
constexpr std::array<const char*, 7> kFuncDatasetOps = {
    "ExperimentalParallelInterleaveDataset",
    "FlatMapDataset",
    "InterleaveDataset",
    "ParallelInterleaveDataset",
    "ParallelInterleaveDatasetV2",
    "ParallelInterleaveDatasetV3"
    "ParallelInterleaveDatasetV3",
    "ParallelInterleaveDatasetV4"
};

constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {

@@ -33,12 +33,10 @@ namespace {
constexpr char kLegacyAutotune[] = "legacy_autotune";
constexpr char kPrefetchDataset[] = "PrefetchDataset";

constexpr std::array<const char*, 5> kAsyncDatasetOps = {
    "ExperimentalMapAndBatchDataset",
    "ParallelMapDataset",
    "ParallelInterleaveDatasetV2",
    "ParallelInterleaveDatasetV3",
    "MapAndBatchDataset",
constexpr std::array<const char*, 6> kAsyncDatasetOps = {
    "ExperimentalMapAndBatchDataset", "ParallelMapDataset",
    "ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV3",
    "ParallelInterleaveDatasetV4", "MapAndBatchDataset",
};

} // namespace

@@ -32,8 +32,9 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
    "ParseExampleDataset",
};

constexpr std::array<const char*, 1> kDeterministicAttrOps = {
constexpr std::array<const char*, 2> kDeterministicAttrOps = {
    "ParallelInterleaveDatasetV3",
    "ParallelInterleaveDatasetV4",
};
} // anonymous namespace

@@ -59,6 +59,10 @@ namespace data {
    ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
@@ -91,34 +95,64 @@ constexpr char kSizeSuffix[] = ".size";
constexpr char kInputsSuffix[] = ".inputs";
constexpr char kIsReadySuffix[] = ".is_ready";

// `kCyclePrefetchFactor * cycle_length` is the number of future cycle elements
// that will be prefetched ahead of time. The purpose of prefetching future
// cycle elements is to overlap expensive initialization (e.g. opening of a
// remote file) with other computation.
constexpr double kCyclePrefetchFactor = 2.0L;
constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";

// `kPerIteratorPrefetchFactor * block_length + 1` is the number of per-iterator
// results that will be prefetched ahead of time. The `+ 1` is to match the
// behavior of the original autotune implementation.
constexpr double kPerIteratorPrefetchFactor = 2.0L;
// `kCyclePrefetchFactor * cycle_length` is the default number of future cycle
// elements that will be prefetched ahead of time. The purpose of prefetching
// future cycle elements is to overlap expensive initialization (e.g. opening of
// a remote file) with other computation.
constexpr double kDefaultCyclePrefetchFactor = 2.0L;

// `kPerIteratorPrefetchFactor * block_length + 1` is the default number of
// per-iterator results that will be prefetched ahead of time. The `+ 1` is to
// match the behavior of the original implementation.
constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;

// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;

namespace {
int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
                                  int64 block_length) {
  if (configured_buffer_output_elements != model::kAutotune) {
    return configured_buffer_output_elements;
  }
  return kDefaultPerIteratorPrefetchFactor * block_length + 1;
}

int64 ComputePrefetchInputElements(int64 configured_prefetch_input_elements,
                                   int64 cycle_length) {
  if (configured_prefetch_input_elements != model::kAutotune) {
    return configured_prefetch_input_elements;
  }
  return kDefaultCyclePrefetchFactor * cycle_length;
}

int64 OpVersionFromOpName(absl::string_view op_name) {
  if (op_name == kParallelInterleaveDatasetV2) {
    return 2;
  } else if (op_name == kParallelInterleaveDatasetV3) {
    return 3;
  } else {
    DCHECK_EQ(op_name, kParallelInterleaveDatasetV4);
    return 4;
  }
}
} // namespace

// The motivation for creating an alternative implementation of parallel
// interleave is to decouple the degree of parallelism from the cycle length.
// This makes it possible to change the degree of parallelism (e.g. through
// auto-tuning) without changing the cycle length (which would change the order
// in which elements are produced).
//
// Furthermore, this class favors modularity over extended functionality. In
// particular, it refrains from implementing configurable buffering of output
// elements and prefetching of input iterators.
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
          int64 block_length, int64 num_parallel_calls,
          int64 block_length, int64 buffer_output_elements,
          int64 prefetch_input_elements, int64 num_parallel_calls,
          DeterminismPolicy deterministic, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
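The two helpers above are what turn `model::kAutotune` into concrete defaults. A small illustrative sketch of the same arithmetic in Python (the `AUTOTUNE` sentinel below is a stand-in, not the real `model::kAutotune` constant):

    # Stand-in sentinel; the factors mirror the C++ defaults above.
    AUTOTUNE = -1
    DEFAULT_CYCLE_PREFETCH_FACTOR = 2
    DEFAULT_PER_ITERATOR_PREFETCH_FACTOR = 2

    def compute_buffer_output_elements(configured, block_length):
      # Per-iterator result buffer: 2 * block_length + 1 unless configured.
      if configured != AUTOTUNE:
        return configured
      return DEFAULT_PER_ITERATOR_PREFETCH_FACTOR * block_length + 1

    def compute_prefetch_input_elements(configured, cycle_length):
      # Input iterators opened ahead of time: 2 * cycle_length unless configured.
      if configured != AUTOTUNE:
        return configured
      return DEFAULT_CYCLE_PREFETCH_FACTOR * cycle_length

    # Example: block_length=16 buffers 33 results per iterator;
    # cycle_length=4 keeps 8 future input iterators warming up.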
@@ -126,6 +160,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        buffer_output_elements_(
            ComputeBufferOutputElements(buffer_output_elements, block_length)),
        prefetch_input_elements_(ComputePrefetchInputElements(
            prefetch_input_elements, cycle_length)),
        num_parallel_calls_(num_parallel_calls),
        deterministic_(deterministic),
        output_types_(output_types),

@@ -179,19 +217,44 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    std::vector<std::pair<size_t, Node*>> inputs;
    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
    int input_index = 0;

    Node* input_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* cycle_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    Node* block_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    Node* num_parallel_calls_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
    inputs.emplace_back(input_index++, input_node);

    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    list_inputs.emplace_back(input_index++, other_arguments);

    Node* cycle_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    inputs.emplace_back(input_index++, cycle_length_node);

    Node* block_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    inputs.emplace_back(input_index++, block_length_node);

    if (op_version_ >= 4) {
      Node* buffer_output_elements_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
      inputs.emplace_back(input_index++, buffer_output_elements_node);

      Node* prefetch_input_elements_node;
      TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
                                      &prefetch_input_elements_node));
      inputs.emplace_back(input_index++, prefetch_input_elements_node);
    }

    Node* num_parallel_calls_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
    inputs.emplace_back(input_index++, num_parallel_calls_node);

    std::vector<std::pair<StringPiece, AttrValue>> attrs;
    AttrValue f;
@@ -207,18 +270,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
      attrs.emplace_back(kSloppy, sloppy_attr);
    }
    if (op_version_ == 3) {
    if (op_version_ >= 3) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    TF_RETURN_IF_ERROR(b->AddDataset(this,
                                     {{0, input_node},
                                      {2, cycle_length_node},
                                      {3, block_length_node},
                                      {4, num_parallel_calls_node}},
                                     {{1, other_arguments}}, attrs, output));
    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
    return Status::OK();
  }

@@ -227,12 +285,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
   public:
    ParallelInterleaveIterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
          per_iterator_prefetch_(
              static_cast<int>(params.dataset->block_length_ *
                               kPerIteratorPrefetchFactor) +
              1),
          future_elements_prefetch_(static_cast<int>(
              params.dataset->cycle_length_ * kCyclePrefetchFactor)),
          mu_(std::make_shared<mutex>()),
          num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
          num_parallel_calls_(std::make_shared<model::SharedState>(

@@ -257,7 +309,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      // collection, `cycle_length_` threads for the current workers, and
      // `future_elements_prefetch_` for the future workers.
      int max_current_workers = dataset()->cycle_length_;
      int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
      int future_workers =
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
      int num_threads = 1 + max_current_workers + future_workers;
      if (ctx->stats_aggregator()) {
        num_threads++;
@@ -656,7 +709,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      // `cycle_length_` future workers to guarantee that whenever
      // `future_element_.size() < future_elements_prefetch_`, there will be a
      // future worker available to create a new future element.
      int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
      int future_workers =
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
      {
        mutex_lock l(*mu_);
        initial_current_workers = num_parallel_calls_->value;

@@ -800,9 +854,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
          current_workers_cond_var_.notify_one();
        }
      }
      while (!cancelled_ &&
             (future_elements_.size() >= future_elements_prefetch_ ||
              wait_for_checkpoint_)) {
      while (!cancelled_ && (future_elements_.size() >=
                                 dataset()->prefetch_input_elements_ ||
                             wait_for_checkpoint_)) {
        WaitWorkerThread(&future_workers_cond_var_, &l);
      }
      if (cancelled_) {

@@ -860,7 +914,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
        mutex_lock l(*mu_);
        element->results.push_back(std::move(result));
        NotifyElementUpdate(element);
        if (element->results.size() == per_iterator_prefetch_) {
        if (element->results.size() == dataset()->buffer_output_elements_) {
          break;
        }
      }

@@ -954,7 +1008,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
        return true;
      }
      return element->iterator &&
             element->results.size() < per_iterator_prefetch_;
             element->results.size() < dataset()->buffer_output_elements_;
    }

    inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {

@@ -1311,9 +1365,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
    // its elements are null.
    int64 last_valid_current_element_ GUARDED_BY(mu_) = -1;

    const int per_iterator_prefetch_;
    const int future_elements_prefetch_;

    // Identifies whether the current_elements_ vector has been initialized.
    bool initial_elements_created_ GUARDED_BY(mu_) = false;

@@ -1421,6 +1472,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int64 cycle_length_;
  const int64 block_length_;
  const int64 buffer_output_elements_;
  const int64 prefetch_input_elements_;
  const int64 num_parallel_calls_;
  const DeterminismPolicy deterministic_;
  const DataTypeVector output_types_;

@@ -1431,7 +1484,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {

ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 2 : 3) {
    : UnaryDatasetOpKernel(ctx),
      op_version_(OpVersionFromOpName(ctx->def().op())) {
  FunctionMetadata::Params params;
  params.is_multi_device_function = true;
  OP_REQUIRES_OK(ctx,

@@ -1448,7 +1502,7 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
      deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
    }
  }
  if (op_version_ == 3) {
  if (op_version_ >= 3) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
@@ -1472,6 +1526,26 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));

  int64 buffer_output_elements = model::kAutotune;
  int64 prefetch_input_elements = model::kAutotune;
  if (op_version_ >= 4) {
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
                                            &buffer_output_elements));
    OP_REQUIRES(ctx,
                buffer_output_elements == model::kAutotune ||
                    buffer_output_elements >= 0,
                errors::InvalidArgument("`buffer_output_elements` must be ",
                                        model::kAutotune, " or >= 0 but is ",
                                        buffer_output_elements));

    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
                                            &prefetch_input_elements));
    OP_REQUIRES(ctx,
                prefetch_input_elements == model::kAutotune ||
                    prefetch_input_elements >= 0,
                errors::InvalidArgument("`prefetch_input_elements` must be ",
                                        model::kAutotune, " or >= 0 but is ",
                                        prefetch_input_elements));
  }

  int64 num_parallel_calls = 0;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
@@ -1492,18 +1566,22 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
                        block_length, num_parallel_calls, deterministic_,
                        output_types_, output_shapes_, op_version_);
  *output = new Dataset(
      ctx, input, std::move(captured_func), cycle_length, block_length,
      buffer_output_elements, prefetch_input_elements, num_parallel_calls,
      deterministic_, output_types_, output_shapes_, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV3").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV2");
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV3");
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
} // namespace
} // namespace data
} // namespace tensorflow

@@ -29,6 +29,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
  static constexpr const char* const kOtherArguments = "other_arguments";
  static constexpr const char* const kCycleLength = "cycle_length";
  static constexpr const char* const kBlockLength = "block_length";
  static constexpr const char* const kBufferOutputElements =
      "buffer_output_elements";
  static constexpr const char* const kPrefetchInputElements =
      "prefetch_input_elements";
  static constexpr const char* const kNumParallelCalls = "num_parallel_calls";
  static constexpr const char* const kFunc = "f";
  static constexpr const char* const kTarguments = "Targuments";
@@ -18,14 +18,15 @@ namespace data {
namespace {

constexpr char kNodeName[] = "parallel_interleave_dataset";
constexpr int kOpVersion = 3;
constexpr int kOpVersion = 4;

class ParallelInterleaveDatasetParams : public DatasetParams {
 public:
  template <typename T>
  ParallelInterleaveDatasetParams(
      T input_dataset_params, std::vector<Tensor> other_arguments,
      int64 cycle_length, int64 block_length, int64 num_parallel_calls,
      int64 cycle_length, int64 block_length, int64 buffer_output_elements,
      int64 prefetch_input_elements, int64 num_parallel_calls,
      FunctionDefHelper::AttrValueWrapper func,
      std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
      const DataTypeVector& output_dtypes,

@@ -36,6 +37,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
        other_arguments_(std::move(other_arguments)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        buffer_output_elements_(buffer_output_elements),
        prefetch_input_elements_(prefetch_input_elements),
        num_parallel_calls_(num_parallel_calls),
        func_(std::move(func)),
        func_lib_(std::move(func_lib)),

@@ -56,6 +59,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
        CreateTensor<int64>(TensorShape({}), {cycle_length_}));
    input_tensors.emplace_back(
        CreateTensor<int64>(TensorShape({}), {block_length_}));
    input_tensors.emplace_back(
        CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
    input_tensors.emplace_back(
        CreateTensor<int64>(TensorShape({}), {prefetch_input_elements_}));
    input_tensors.emplace_back(
        CreateTensor<int64>(TensorShape({}), {num_parallel_calls_}));
    return input_tensors;

@@ -69,6 +76,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
    }
    input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
    input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
    input_names->emplace_back(
        ParallelInterleaveDatasetOp::kBufferOutputElements);
    input_names->emplace_back(
        ParallelInterleaveDatasetOp::kPrefetchInputElements);
    input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls);
    return Status::OK();
  }

@@ -93,6 +104,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
  std::vector<Tensor> other_arguments_;
  int64 cycle_length_;
  int64 block_length_;
  int64 buffer_output_elements_;
  int64 prefetch_input_elements_;
  int64 num_parallel_calls_;
  FunctionDefHelper::AttrValueWrapper func_;
  std::vector<FunctionDef> func_lib_;
@@ -111,8 +124,6 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
                  {"output_shapes", output_shapes}});
}

// test case 1: cycle_length = 1, block_length = 1, num_parallel_calls = 1,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},

@@ -123,6 +134,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/1,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -136,8 +149,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
      /*node_name=*/kNodeName);
}

// test case 2: cycle_length = 2, block_length = 1, num_parallel_calls = 2,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},

@@ -148,6 +159,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
      /*other_arguments=*/{},
      /*cycle_length=*/2,
      /*block_length=*/1,
      /*buffer_output_elements=*/0,
      /*prefetch_input_elements=*/0,
      /*num_parallel_calls=*/2,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -161,8 +174,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
      /*node_name=*/kNodeName);
}

// test case 3: cycle_length = 3, block_length = 1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},

@@ -173,6 +184,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
      /*other_arguments=*/{},
      /*cycle_length=*/3,
      /*block_length=*/1,
      /*buffer_output_elements=*/0,
      /*prefetch_input_elements=*/1,
      /*num_parallel_calls=*/2,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -186,9 +199,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
      /*node_name=*/kNodeName);
}

// test case 4: cycle_length = 5, block_length = 1, num_parallel_calls = 4,
// sloppy = true

ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},

@@ -199,6 +209,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
      /*other_arguments=*/{},
      /*cycle_length=*/5,
      /*block_length=*/1,
      /*buffer_output_elements=*/1,
      /*prefetch_input_elements=*/0,
      /*num_parallel_calls=*/4,
      /*func=*/
      MakeTensorSliceDatasetFunc(
@@ -212,8 +224,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
      /*node_name=*/kNodeName);
}

// test case 5: cycle_length = 2, block_length = 2, num_parallel_calls = 1,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<tstring>(

@@ -224,6 +234,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
      /*other_arguments=*/{},
      /*cycle_length=*/2,
      /*block_length=*/2,
      /*buffer_output_elements=*/2,
      /*prefetch_input_elements=*/2,
      /*num_parallel_calls=*/1,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -237,8 +249,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
      /*node_name=*/kNodeName);
}

// test case 6: cycle_length = 2, block_length = 3, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<tstring>(

@@ -249,6 +259,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
      /*other_arguments=*/{},
      /*cycle_length=*/2,
      /*block_length=*/3,
      /*buffer_output_elements=*/100,
      /*prefetch_input_elements=*/100,
      /*num_parallel_calls=*/2,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -262,8 +274,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
      /*node_name=*/kNodeName);
}

// test case 7: cycle_length = 3, block_length = 2, num_parallel_calls = 2,
// sloppy = false
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<tstring>(

@@ -274,6 +284,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
      /*other_arguments=*/{},
      /*cycle_length=*/3,
      /*block_length=*/2,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/2,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -287,8 +299,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
      /*node_name=*/kNodeName);
}

// test case 8: cycle_length = 3, block_length = 3, num_parallel_calls = 3,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<tstring>(

@@ -299,6 +309,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
      /*other_arguments=*/{},
      /*cycle_length=*/3,
      /*block_length=*/3,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/3,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -312,8 +324,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
      /*node_name=*/kNodeName);
}

// test case 9: cycle_length = 4, block_length = 4, num_parallel_calls = 4,
// sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<tstring>(

@@ -324,6 +334,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
      /*other_arguments=*/{},
      /*cycle_length=*/4,
      /*block_length=*/4,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/4,
      /*func=*/
      MakeTensorSliceDatasetFunc(
@@ -337,8 +349,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
      /*node_name=*/kNodeName);
}

// test case 10: cycle_length = 3, block_length = 3,
// num_parallel_calls = kAutotune, sloppy = true
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<tstring>(

@@ -349,6 +359,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
      /*other_arguments=*/{},
      /*cycle_length=*/4,
      /*block_length=*/4,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/model::kAutotune,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -372,6 +384,8 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
      /*other_arguments=*/{},
      /*cycle_length=*/11,
      /*block_length=*/1,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/2,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -385,8 +399,6 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
      /*node_name=*/kNodeName);
}

// test case 11: cycle_length = 0, block_length = 1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(

@@ -398,6 +410,8 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
      /*other_arguments=*/{},
      /*cycle_length=*/0,
      /*block_length=*/1,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/2,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -411,8 +425,6 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
      /*node_name=*/kNodeName);
}

// test case 12: cycle_length = 1, block_length = -1, num_parallel_calls = 2,
// sloppy = true
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(

@@ -424,6 +436,8 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/-1,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/2,
      /*func=*/
      MakeTensorSliceDatasetFunc(

@@ -437,8 +451,6 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
      /*node_name=*/kNodeName);
}

// test case 13: cycle_length = 1, block_length = 1, num_parallel_calls = -5,
// sloppy = true
ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@@ -450,6 +462,60 @@ ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/-5,
      /*func=*/
      MakeTensorSliceDatasetFunc(
          DataTypeVector({DT_INT64}),
          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
                                          {0, 1, 2, 3, 4, 5, 6, 7, 8})},
      /*node_name=*/"tensor_slice");
  return ParallelInterleaveDatasetParams(
      tensor_slice_dataset_params,
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
      /*buffer_output_elements=*/-2,
      /*prefetch_input_elements=*/model::kAutotune,
      /*num_parallel_calls=*/-5,
      /*func=*/
      MakeTensorSliceDatasetFunc(
          DataTypeVector({DT_INT64}),
          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

ParallelInterleaveDatasetParams
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
                                          {0, 1, 2, 3, 4, 5, 6, 7, 8})},
      /*node_name=*/"tensor_slice");
  return ParallelInterleaveDatasetParams(
      tensor_slice_dataset_params,
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
      /*buffer_output_elements=*/model::kAutotune,
      /*prefetch_input_elements=*/-2,
      /*num_parallel_calls=*/-5,
      /*func=*/
      MakeTensorSliceDatasetFunc(
@@ -698,7 +764,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
  std::vector<ParallelInterleaveDatasetParams> invalid_params = {
      ParallelInterleaveDatasetParamsWithInvalidCycleLength(),
      ParallelInterleaveDatasetParamsWithInvalidBlockLength(),
      ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls()};
      ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls(),
      ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements(),
      ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements(),
  };
  for (auto& dataset_params : invalid_params) {
    EXPECT_EQ(Initialize(dataset_params).code(),
              tensorflow::error::INVALID_ARGUMENT);
@@ -227,6 +227,24 @@ REGISTER_OP("ParallelInterleaveDatasetV3")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Like V3, but adds buffer_output_elements and prefetch_input_elements.
REGISTER_OP("ParallelInterleaveDatasetV4")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("buffer_output_elements: int64")
    .Input("prefetch_input_elements: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("FilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
@@ -64,6 +64,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
    parallel_interleave = "ParallelInterleaveV2"
    if compat.forward_compatible(2020, 2, 20):
      parallel_interleave = "ParallelInterleaveV3"
    if compat.forward_compatible(2020, 3, 6):
      parallel_interleave = "ParallelInterleaveV4"
    dataset = dataset.apply(
        testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
    dataset = dataset.interleave(

@@ -79,6 +81,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
    parallel_interleave = "ParallelInterleaveV2"
    if compat.forward_compatible(2020, 2, 20):
      parallel_interleave = "ParallelInterleaveV3"
    if compat.forward_compatible(2020, 3, 6):
      parallel_interleave = "ParallelInterleaveV4"
    dataset = dataset.apply(
        testing.assert_next([
            "ParallelMap", "Prefetch", parallel_interleave, "Prefetch",
@@ -1714,9 +1714,13 @@ name=None))
    if num_parallel_calls is None:
      return InterleaveDataset(self, map_func, cycle_length, block_length)
    else:
      return ParallelInterleaveDataset(self, map_func, cycle_length,
                                       block_length, num_parallel_calls,
                                       deterministic)
      return ParallelInterleaveDataset(
          self,
          map_func,
          cycle_length,
          block_length,
          num_parallel_calls,
          deterministic=deterministic)

  def filter(self, predicate):
    """Filters this dataset according to `predicate`.

@@ -4042,6 +4046,8 @@ class ParallelInterleaveDataset(UnaryDataset):
               cycle_length,
               block_length,
               num_parallel_calls,
               buffer_output_elements=AUTOTUNE,
               prefetch_input_elements=AUTOTUNE,
               deterministic=None):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
@@ -4056,6 +4062,15 @@ class ParallelInterleaveDataset(UnaryDataset):
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._buffer_output_elements = ops.convert_to_tensor(
        buffer_output_elements,
        dtype=dtypes.int64,
        name="buffer_output_elements")
    self._prefetch_input_elements = ops.convert_to_tensor(
        prefetch_input_elements,
        dtype=dtypes.int64,
        name="prefetch_input_elements")

    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    if deterministic is None:

@@ -4065,7 +4080,21 @@ class ParallelInterleaveDataset(UnaryDataset):
    else:
      deterministic_string = "false"

    if deterministic is not None or compat.forward_compatible(2020, 2, 20):
    if (buffer_output_elements != AUTOTUNE or
        prefetch_input_elements != AUTOTUNE or
        compat.forward_compatible(2020, 3, 6)):
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
          self._cycle_length,
          self._block_length,
          self._buffer_output_elements,
          self._prefetch_input_elements,
          self._num_parallel_calls,
          f=self._map_func.function,
          deterministic=deterministic_string,
          **self._flat_structure)
    elif deterministic is not None or compat.forward_compatible(2020, 2, 20):
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
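For reference, a hedged sketch of how the two new constructor parameters could be exercised through the internal `ParallelInterleaveDataset` class changed above (internal API; only the positional order and the `buffer_output_elements` / `prefetch_input_elements` / `deterministic` keywords are taken from this diff, everything else is an assumption, and `Dataset.interleave()` remains the supported entry point):

    import tensorflow as tf
    from tensorflow.python.data.ops import dataset_ops  # internal module

    source = tf.data.Dataset.range(10)

    # Positional arguments follow the interleave() call site above:
    # input_dataset, map_func, cycle_length, block_length, num_parallel_calls.
    ds = dataset_ops.ParallelInterleaveDataset(
        source,
        lambda x: tf.data.Dataset.from_tensors(x),
        2,  # cycle_length
        1,  # block_length
        2,  # num_parallel_calls
        buffer_output_elements=4,   # per-iterator buffer; defaults to AUTOTUNE
        prefetch_input_elements=2,  # input iterators opened ahead; defaults to AUTOTUNE
        deterministic=None)         # defer to tf.data.Options

    for element in ds:
      print(element.numpy())

Passing explicit (non-AUTOTUNE) values routes construction to `parallel_interleave_dataset_v4` per the branch added above.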
@@ -2636,6 +2636,10 @@ tf_module {
    name: "ParallelInterleaveDatasetV3"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
  }
  member_method {
    name: "ParallelInterleaveDatasetV4"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
  }
  member_method {
    name: "ParallelMapDataset"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "

@@ -2636,6 +2636,10 @@ tf_module {
    name: "ParallelInterleaveDatasetV3"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
  }
  member_method {
    name: "ParallelInterleaveDatasetV4"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
  }
  member_method {
    name: "ParallelMapDataset"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "