Add API to control parallel interleave prefetching.
PiperOrigin-RevId: 294766661 Change-Id: I8061629522d19d408cd8b7a1981836a4ee958110
This commit is contained in:
parent
a7f1d52b03
commit
0c1ca5c674
@ -0,0 +1,100 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "ParallelInterleaveDatasetV4"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "input_dataset"
|
||||||
|
description: <<END
|
||||||
|
Dataset that produces a stream of arguments for the function `f`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "other_arguments"
|
||||||
|
description: <<END
|
||||||
|
Additional arguments to pass to `f` beyond those produced by `input_dataset`.
|
||||||
|
Evaluated once when the dataset is instantiated.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "cycle_length"
|
||||||
|
description: <<END
|
||||||
|
Number of datasets (each created by applying `f` to the elements of
|
||||||
|
`input_dataset`) among which the `ParallelInterleaveDatasetV2` will cycle in a
|
||||||
|
round-robin fashion.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "block_length"
|
||||||
|
description: <<END
|
||||||
|
Number of elements at a time to produce from each interleaved invocation of a
|
||||||
|
dataset returned by `f`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "buffer_output_elements"
|
||||||
|
description: <<END
|
||||||
|
The number of elements each iterator being interleaved should buffer (similar
|
||||||
|
to the `.prefetch()` transformation for each interleaved iterator).
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "prefetch_input_elements"
|
||||||
|
description: <<END
|
||||||
|
Determines the number of iterators to prefetch, allowing buffers to warm up and
|
||||||
|
data to be pre-fetched without blocking the main thread.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "num_parallel_calls"
|
||||||
|
description: <<END
|
||||||
|
Determines the number of threads that should be used for fetching data from
|
||||||
|
input datasets in parallel. The Python API `tf.data.experimental.AUTOTUNE`
|
||||||
|
constant can be used to indicate that the level of parallelism should be autotuned.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "f"
|
||||||
|
description: <<END
|
||||||
|
A function mapping elements of `input_dataset`, concatenated with
|
||||||
|
`other_arguments`, to a Dataset variant that contains elements matching
|
||||||
|
`output_types` and `output_shapes`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "deterministic"
|
||||||
|
description: <<END
|
||||||
|
A string indicating the op-level determinism to use. Deterministic controls
|
||||||
|
whether the interleave is allowed to return elements out of order if the next
|
||||||
|
element to be returned isn't available, but a later element is. Options are
|
||||||
|
"true", "false", and "default". "default" indicates that determinism should be
|
||||||
|
decided by the `experimental_deterministic` parameter of `tf.data.Options`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Targuments"
|
||||||
|
description: <<END
|
||||||
|
Types of the elements of `other_arguments`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "output_types"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "output_shapes"
|
||||||
|
}
|
||||||
|
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
|
||||||
|
description: <<END
|
||||||
|
The resulting dataset is similar to the `InterleaveDataset`, except that the
|
||||||
|
dataset will fetch records from the interleaved datasets in parallel.
|
||||||
|
|
||||||
|
The `tf.data` Python API creates instances of this op from
|
||||||
|
`Dataset.interleave()` when the `num_parallel_calls` parameter of that method
|
||||||
|
is set to any value other than `None`.
|
||||||
|
|
||||||
|
By default, the output of this dataset will be deterministic, which may result
|
||||||
|
in the dataset blocking if the next data item to be returned isn't available.
|
||||||
|
In order to avoid head-of-line blocking, one can either set the `deterministic`
|
||||||
|
attribute to "false", or leave it as "default" and set the
|
||||||
|
`experimental_deterministic` parameter of `tf.data.Options` to `False`.
|
||||||
|
This can improve performance at the expense of non-determinism.
|
||||||
|
END
|
||||||
|
}
|
@ -89,13 +89,14 @@ constexpr std::array<const char*, 28> kPassThroughOps = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
|
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
|
||||||
constexpr std::array<const char*, 6> kFuncDatasetOps = {
|
constexpr std::array<const char*, 7> kFuncDatasetOps = {
|
||||||
"ExperimentalParallelInterleaveDataset",
|
"ExperimentalParallelInterleaveDataset",
|
||||||
"FlatMapDataset",
|
"FlatMapDataset",
|
||||||
"InterleaveDataset",
|
"InterleaveDataset",
|
||||||
"ParallelInterleaveDataset",
|
"ParallelInterleaveDataset",
|
||||||
"ParallelInterleaveDatasetV2",
|
"ParallelInterleaveDatasetV2",
|
||||||
"ParallelInterleaveDatasetV3"
|
"ParallelInterleaveDatasetV3",
|
||||||
|
"ParallelInterleaveDatasetV4"
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
|
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
|
||||||
|
@ -33,12 +33,10 @@ namespace {
|
|||||||
constexpr char kLegacyAutotune[] = "legacy_autotune";
|
constexpr char kLegacyAutotune[] = "legacy_autotune";
|
||||||
constexpr char kPrefetchDataset[] = "PrefetchDataset";
|
constexpr char kPrefetchDataset[] = "PrefetchDataset";
|
||||||
|
|
||||||
constexpr std::array<const char*, 5> kAsyncDatasetOps = {
|
constexpr std::array<const char*, 6> kAsyncDatasetOps = {
|
||||||
"ExperimentalMapAndBatchDataset",
|
"ExperimentalMapAndBatchDataset", "ParallelMapDataset",
|
||||||
"ParallelMapDataset",
|
"ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV3",
|
||||||
"ParallelInterleaveDatasetV2",
|
"ParallelInterleaveDatasetV4", "MapAndBatchDataset",
|
||||||
"ParallelInterleaveDatasetV3",
|
|
||||||
"MapAndBatchDataset",
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -32,8 +32,9 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
|
|||||||
"ParseExampleDataset",
|
"ParseExampleDataset",
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr std::array<const char*, 1> kDeterministicAttrOps = {
|
constexpr std::array<const char*, 2> kDeterministicAttrOps = {
|
||||||
"ParallelInterleaveDatasetV3",
|
"ParallelInterleaveDatasetV3",
|
||||||
|
"ParallelInterleaveDatasetV4",
|
||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
|
@ -59,6 +59,10 @@ namespace data {
|
|||||||
ParallelInterleaveDatasetOp::kCycleLength;
|
ParallelInterleaveDatasetOp::kCycleLength;
|
||||||
/* static */ constexpr const char* const
|
/* static */ constexpr const char* const
|
||||||
ParallelInterleaveDatasetOp::kBlockLength;
|
ParallelInterleaveDatasetOp::kBlockLength;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
ParallelInterleaveDatasetOp::kBufferOutputElements;
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
ParallelInterleaveDatasetOp::kPrefetchInputElements;
|
||||||
/* static */ constexpr const char* const
|
/* static */ constexpr const char* const
|
||||||
ParallelInterleaveDatasetOp::kNumParallelCalls;
|
ParallelInterleaveDatasetOp::kNumParallelCalls;
|
||||||
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
|
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
|
||||||
@ -91,34 +95,64 @@ constexpr char kSizeSuffix[] = ".size";
|
|||||||
constexpr char kInputsSuffix[] = ".inputs";
|
constexpr char kInputsSuffix[] = ".inputs";
|
||||||
constexpr char kIsReadySuffix[] = ".is_ready";
|
constexpr char kIsReadySuffix[] = ".is_ready";
|
||||||
|
|
||||||
// `kCyclePrefetchFactor * cycle_length` is the number of future cycle elements
|
constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
|
||||||
// that will be prefetched ahead of time. The purpose of prefetching future
|
constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
|
||||||
// cycle elements is to overlap expensive initialization (e.g. opening of a
|
constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";
|
||||||
// remote file) with other computation.
|
|
||||||
constexpr double kCyclePrefetchFactor = 2.0L;
|
|
||||||
|
|
||||||
// `kPerIteratorPrefetchFactor * block_length + 1` is the number of per-iterator
|
// `kCyclePrefetchFactor * cycle_length` is the default number of future cycle
|
||||||
// results that will be prefetched ahead of time. The `+ 1` is to match the
|
// elements that will be prefetched ahead of time. The purpose of prefetching
|
||||||
// behavior of the original autotune implementation.
|
// future cycle elements is to overlap expensive initialization (e.g. opening of
|
||||||
constexpr double kPerIteratorPrefetchFactor = 2.0L;
|
// a remote file) with other computation.
|
||||||
|
constexpr double kDefaultCyclePrefetchFactor = 2.0L;
|
||||||
|
|
||||||
|
// `kPerIteratorPrefetchFactor * block_length + 1` is the defualt number of
|
||||||
|
// per-iterator results that will be prefetched ahead of time. The `+ 1` is to
|
||||||
|
// match the behavior of the original implementation.
|
||||||
|
constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;
|
||||||
|
|
||||||
// Period between reporting dataset statistics.
|
// Period between reporting dataset statistics.
|
||||||
constexpr int kStatsReportingPeriodMillis = 1000;
|
constexpr int kStatsReportingPeriodMillis = 1000;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
|
||||||
|
int64 block_length) {
|
||||||
|
if (configured_buffer_output_elements != model::kAutotune) {
|
||||||
|
return configured_buffer_output_elements;
|
||||||
|
}
|
||||||
|
return kDefaultPerIteratorPrefetchFactor * block_length + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 ComputePrefetchInputElements(int64 configured_prefetch_input_elements,
|
||||||
|
int64 cycle_length) {
|
||||||
|
if (configured_prefetch_input_elements != model::kAutotune) {
|
||||||
|
return configured_prefetch_input_elements;
|
||||||
|
}
|
||||||
|
return kDefaultCyclePrefetchFactor * cycle_length;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 OpVersionFromOpName(absl::string_view op_name) {
|
||||||
|
if (op_name == kParallelInterleaveDatasetV2) {
|
||||||
|
return 2;
|
||||||
|
} else if (op_name == kParallelInterleaveDatasetV3) {
|
||||||
|
return 3;
|
||||||
|
} else {
|
||||||
|
DCHECK_EQ(op_name, kParallelInterleaveDatasetV4);
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// The motivation for creating an alternative implementation of parallel
|
// The motivation for creating an alternative implementation of parallel
|
||||||
// interleave is to decouple the degree of parallelism from the cycle length.
|
// interleave is to decouple the degree of parallelism from the cycle length.
|
||||||
// This makes it possible to change the degree of parallelism (e.g. through
|
// This makes it possible to change the degree of parallelism (e.g. through
|
||||||
// auto-tuning) without changing the cycle length (which would change the order
|
// auto-tuning) without changing the cycle length (which would change the order
|
||||||
// in which elements are produced).
|
// in which elements are produced).
|
||||||
//
|
|
||||||
// Furthermore, this class favors modularity over extended functionality. In
|
|
||||||
// particular, it refrains from implementing configurable buffering of output
|
|
||||||
// elements and prefetching of input iterators.
|
|
||||||
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||||
public:
|
public:
|
||||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||||
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
|
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
|
||||||
int64 block_length, int64 num_parallel_calls,
|
int64 block_length, int64 buffer_output_elements,
|
||||||
|
int64 prefetch_input_elements, int64 num_parallel_calls,
|
||||||
DeterminismPolicy deterministic, const DataTypeVector& output_types,
|
DeterminismPolicy deterministic, const DataTypeVector& output_types,
|
||||||
const std::vector<PartialTensorShape>& output_shapes, int op_version)
|
const std::vector<PartialTensorShape>& output_shapes, int op_version)
|
||||||
: DatasetBase(DatasetContext(ctx)),
|
: DatasetBase(DatasetContext(ctx)),
|
||||||
@ -126,6 +160,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
captured_func_(std::move(captured_func)),
|
captured_func_(std::move(captured_func)),
|
||||||
cycle_length_(cycle_length),
|
cycle_length_(cycle_length),
|
||||||
block_length_(block_length),
|
block_length_(block_length),
|
||||||
|
buffer_output_elements_(
|
||||||
|
ComputeBufferOutputElements(buffer_output_elements, block_length)),
|
||||||
|
prefetch_input_elements_(ComputePrefetchInputElements(
|
||||||
|
prefetch_input_elements, cycle_length)),
|
||||||
num_parallel_calls_(num_parallel_calls),
|
num_parallel_calls_(num_parallel_calls),
|
||||||
deterministic_(deterministic),
|
deterministic_(deterministic),
|
||||||
output_types_(output_types),
|
output_types_(output_types),
|
||||||
@ -179,19 +217,44 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||||
DatasetGraphDefBuilder* b,
|
DatasetGraphDefBuilder* b,
|
||||||
Node** output) const override {
|
Node** output) const override {
|
||||||
|
std::vector<std::pair<size_t, Node*>> inputs;
|
||||||
|
std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
|
||||||
|
int input_index = 0;
|
||||||
|
|
||||||
Node* input_node;
|
Node* input_node;
|
||||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
|
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
|
||||||
Node* cycle_length_node;
|
inputs.emplace_back(input_index++, input_node);
|
||||||
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
|
|
||||||
Node* block_length_node;
|
|
||||||
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
|
|
||||||
Node* num_parallel_calls_node;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
|
|
||||||
std::vector<Node*> other_arguments;
|
std::vector<Node*> other_arguments;
|
||||||
DataTypeVector other_arguments_types;
|
DataTypeVector other_arguments_types;
|
||||||
TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
|
TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
|
||||||
&other_arguments_types));
|
&other_arguments_types));
|
||||||
|
list_inputs.emplace_back(input_index++, other_arguments);
|
||||||
|
|
||||||
|
Node* cycle_length_node;
|
||||||
|
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
|
||||||
|
inputs.emplace_back(input_index++, cycle_length_node);
|
||||||
|
|
||||||
|
Node* block_length_node;
|
||||||
|
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
|
||||||
|
inputs.emplace_back(input_index++, block_length_node);
|
||||||
|
|
||||||
|
if (op_version_ >= 4) {
|
||||||
|
Node* buffer_output_elements_node;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
|
||||||
|
inputs.emplace_back(input_index++, buffer_output_elements_node);
|
||||||
|
|
||||||
|
Node* prefetch_input_elements_node;
|
||||||
|
TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
|
||||||
|
&prefetch_input_elements_node));
|
||||||
|
inputs.emplace_back(input_index++, prefetch_input_elements_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* num_parallel_calls_node;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
|
||||||
|
inputs.emplace_back(input_index++, num_parallel_calls_node);
|
||||||
|
|
||||||
std::vector<std::pair<StringPiece, AttrValue>> attrs;
|
std::vector<std::pair<StringPiece, AttrValue>> attrs;
|
||||||
AttrValue f;
|
AttrValue f;
|
||||||
@ -207,18 +270,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
|
b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
|
||||||
attrs.emplace_back(kSloppy, sloppy_attr);
|
attrs.emplace_back(kSloppy, sloppy_attr);
|
||||||
}
|
}
|
||||||
if (op_version_ == 3) {
|
if (op_version_ >= 3) {
|
||||||
AttrValue deterministic_attr;
|
AttrValue deterministic_attr;
|
||||||
b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
|
b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
|
||||||
attrs.emplace_back(kDeterministic, deterministic_attr);
|
attrs.emplace_back(kDeterministic, deterministic_attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(b->AddDataset(this,
|
TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
|
||||||
{{0, input_node},
|
|
||||||
{2, cycle_length_node},
|
|
||||||
{3, block_length_node},
|
|
||||||
{4, num_parallel_calls_node}},
|
|
||||||
{{1, other_arguments}}, attrs, output));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,12 +285,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
public:
|
public:
|
||||||
ParallelInterleaveIterator(const Params& params, bool deterministic)
|
ParallelInterleaveIterator(const Params& params, bool deterministic)
|
||||||
: DatasetIterator<Dataset>(params),
|
: DatasetIterator<Dataset>(params),
|
||||||
per_iterator_prefetch_(
|
|
||||||
static_cast<int>(params.dataset->block_length_ *
|
|
||||||
kPerIteratorPrefetchFactor) +
|
|
||||||
1),
|
|
||||||
future_elements_prefetch_(static_cast<int>(
|
|
||||||
params.dataset->cycle_length_ * kCyclePrefetchFactor)),
|
|
||||||
mu_(std::make_shared<mutex>()),
|
mu_(std::make_shared<mutex>()),
|
||||||
num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
|
num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
|
||||||
num_parallel_calls_(std::make_shared<model::SharedState>(
|
num_parallel_calls_(std::make_shared<model::SharedState>(
|
||||||
@ -257,7 +309,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
// collection, `cycle_length_` threads for the current workers, and
|
// collection, `cycle_length_` threads for the current workers, and
|
||||||
// `future_elements_prefetch_` for the future workers.
|
// `future_elements_prefetch_` for the future workers.
|
||||||
int max_current_workers = dataset()->cycle_length_;
|
int max_current_workers = dataset()->cycle_length_;
|
||||||
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
|
int future_workers =
|
||||||
|
dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
|
||||||
int num_threads = 1 + max_current_workers + future_workers;
|
int num_threads = 1 + max_current_workers + future_workers;
|
||||||
if (ctx->stats_aggregator()) {
|
if (ctx->stats_aggregator()) {
|
||||||
num_threads++;
|
num_threads++;
|
||||||
@ -656,7 +709,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
// `cycle_length_` future workers to guarantee that whenever
|
// `cycle_length_` future workers to guarantee that whenever
|
||||||
// `future_element_.size() < future_elements_prefetch_`, there will be a
|
// `future_element_.size() < future_elements_prefetch_`, there will be a
|
||||||
// future worker available to create a new future element.
|
// future worker available to create a new future element.
|
||||||
int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
|
int future_workers =
|
||||||
|
dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
|
||||||
{
|
{
|
||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
initial_current_workers = num_parallel_calls_->value;
|
initial_current_workers = num_parallel_calls_->value;
|
||||||
@ -800,8 +854,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
current_workers_cond_var_.notify_one();
|
current_workers_cond_var_.notify_one();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
while (!cancelled_ &&
|
while (!cancelled_ && (future_elements_.size() >=
|
||||||
(future_elements_.size() >= future_elements_prefetch_ ||
|
dataset()->prefetch_input_elements_ ||
|
||||||
wait_for_checkpoint_)) {
|
wait_for_checkpoint_)) {
|
||||||
WaitWorkerThread(&future_workers_cond_var_, &l);
|
WaitWorkerThread(&future_workers_cond_var_, &l);
|
||||||
}
|
}
|
||||||
@ -860,7 +914,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
mutex_lock l(*mu_);
|
mutex_lock l(*mu_);
|
||||||
element->results.push_back(std::move(result));
|
element->results.push_back(std::move(result));
|
||||||
NotifyElementUpdate(element);
|
NotifyElementUpdate(element);
|
||||||
if (element->results.size() == per_iterator_prefetch_) {
|
if (element->results.size() == dataset()->buffer_output_elements_) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -954,7 +1008,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return element->iterator &&
|
return element->iterator &&
|
||||||
element->results.size() < per_iterator_prefetch_;
|
element->results.size() < dataset()->buffer_output_elements_;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
@ -1311,9 +1365,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
// its elements are null.
|
// its elements are null.
|
||||||
int64 last_valid_current_element_ GUARDED_BY(mu_) = -1;
|
int64 last_valid_current_element_ GUARDED_BY(mu_) = -1;
|
||||||
|
|
||||||
const int per_iterator_prefetch_;
|
|
||||||
const int future_elements_prefetch_;
|
|
||||||
|
|
||||||
// Identifies whether the current_elements_ vector has been initialized.
|
// Identifies whether the current_elements_ vector has been initialized.
|
||||||
bool initial_elements_created_ GUARDED_BY(mu_) = false;
|
bool initial_elements_created_ GUARDED_BY(mu_) = false;
|
||||||
|
|
||||||
@ -1421,6 +1472,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
const std::unique_ptr<CapturedFunction> captured_func_;
|
const std::unique_ptr<CapturedFunction> captured_func_;
|
||||||
const int64 cycle_length_;
|
const int64 cycle_length_;
|
||||||
const int64 block_length_;
|
const int64 block_length_;
|
||||||
|
const int64 buffer_output_elements_;
|
||||||
|
const int64 prefetch_input_elements_;
|
||||||
const int64 num_parallel_calls_;
|
const int64 num_parallel_calls_;
|
||||||
const DeterminismPolicy deterministic_;
|
const DeterminismPolicy deterministic_;
|
||||||
const DataTypeVector output_types_;
|
const DataTypeVector output_types_;
|
||||||
@ -1431,7 +1484,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
|
|
||||||
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
|
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
|
||||||
OpKernelConstruction* ctx)
|
OpKernelConstruction* ctx)
|
||||||
: UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 2 : 3) {
|
: UnaryDatasetOpKernel(ctx),
|
||||||
|
op_version_(OpVersionFromOpName(ctx->def().op())) {
|
||||||
FunctionMetadata::Params params;
|
FunctionMetadata::Params params;
|
||||||
params.is_multi_device_function = true;
|
params.is_multi_device_function = true;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
@ -1448,7 +1502,7 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
|
|||||||
deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
|
deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (op_version_ == 3) {
|
if (op_version_ >= 3) {
|
||||||
std::string deterministic;
|
std::string deterministic;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
@ -1472,6 +1526,26 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
|
|||||||
OP_REQUIRES(ctx, block_length > 0,
|
OP_REQUIRES(ctx, block_length > 0,
|
||||||
errors::InvalidArgument("`block_length` must be > 0"));
|
errors::InvalidArgument("`block_length` must be > 0"));
|
||||||
|
|
||||||
|
int64 buffer_output_elements = model::kAutotune;
|
||||||
|
int64 prefetch_input_elements = model::kAutotune;
|
||||||
|
if (op_version_ >= 4) {
|
||||||
|
OP_REQUIRES(ctx,
|
||||||
|
buffer_output_elements == model::kAutotune ||
|
||||||
|
buffer_output_elements >= 0,
|
||||||
|
errors::InvalidArgument("`buffer_output_elements` must be ",
|
||||||
|
model::kAutotune, " or >= 0 but is ",
|
||||||
|
buffer_output_elements));
|
||||||
|
|
||||||
|
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
|
||||||
|
&prefetch_input_elements));
|
||||||
|
OP_REQUIRES(ctx,
|
||||||
|
prefetch_input_elements == model::kAutotune ||
|
||||||
|
prefetch_input_elements >= 0,
|
||||||
|
errors::InvalidArgument("`prefetch_input_elements` must be ",
|
||||||
|
model::kAutotune, " or >= 0 but is ",
|
||||||
|
prefetch_input_elements));
|
||||||
|
}
|
||||||
|
|
||||||
int64 num_parallel_calls = 0;
|
int64 num_parallel_calls = 0;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
|
ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
|
||||||
@ -1492,18 +1566,22 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
|
|||||||
metrics::RecordTFDataAutotune(kDatasetType);
|
metrics::RecordTFDataAutotune(kDatasetType);
|
||||||
}
|
}
|
||||||
|
|
||||||
*output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
|
*output = new Dataset(
|
||||||
block_length, num_parallel_calls, deterministic_,
|
ctx, input, std::move(captured_func), cycle_length, block_length,
|
||||||
output_types_, output_shapes_, op_version_);
|
buffer_output_elements, prefetch_input_elements, num_parallel_calls,
|
||||||
|
deterministic_, output_types_, output_shapes_, op_version_);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
|
||||||
ParallelInterleaveDatasetOp);
|
ParallelInterleaveDatasetOp);
|
||||||
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV3").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
|
||||||
ParallelInterleaveDatasetOp);
|
ParallelInterleaveDatasetOp);
|
||||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV2");
|
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
|
||||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV3");
|
ParallelInterleaveDatasetOp);
|
||||||
|
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
|
||||||
|
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
|
||||||
|
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -29,6 +29,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
static constexpr const char* const kOtherArguments = "other_arguments";
|
static constexpr const char* const kOtherArguments = "other_arguments";
|
||||||
static constexpr const char* const kCycleLength = "cycle_length";
|
static constexpr const char* const kCycleLength = "cycle_length";
|
||||||
static constexpr const char* const kBlockLength = "block_length";
|
static constexpr const char* const kBlockLength = "block_length";
|
||||||
|
static constexpr const char* const kBufferOutputElements =
|
||||||
|
"buffer_output_elements";
|
||||||
|
static constexpr const char* const kPrefetchInputElements =
|
||||||
|
"prefetch_input_elements";
|
||||||
static constexpr const char* const kNumParallelCalls = "num_parallel_calls";
|
static constexpr const char* const kNumParallelCalls = "num_parallel_calls";
|
||||||
static constexpr const char* const kFunc = "f";
|
static constexpr const char* const kFunc = "f";
|
||||||
static constexpr const char* const kTarguments = "Targuments";
|
static constexpr const char* const kTarguments = "Targuments";
|
||||||
|
@ -18,14 +18,15 @@ namespace data {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kNodeName[] = "parallel_interleave_dataset";
|
constexpr char kNodeName[] = "parallel_interleave_dataset";
|
||||||
constexpr int kOpVersion = 3;
|
constexpr int kOpVersion = 4;
|
||||||
|
|
||||||
class ParallelInterleaveDatasetParams : public DatasetParams {
|
class ParallelInterleaveDatasetParams : public DatasetParams {
|
||||||
public:
|
public:
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ParallelInterleaveDatasetParams(
|
ParallelInterleaveDatasetParams(
|
||||||
T input_dataset_params, std::vector<Tensor> other_arguments,
|
T input_dataset_params, std::vector<Tensor> other_arguments,
|
||||||
int64 cycle_length, int64 block_length, int64 num_parallel_calls,
|
int64 cycle_length, int64 block_length, int64 buffer_output_elements,
|
||||||
|
int64 prefetch_input_elements, int64 num_parallel_calls,
|
||||||
FunctionDefHelper::AttrValueWrapper func,
|
FunctionDefHelper::AttrValueWrapper func,
|
||||||
std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
|
std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
|
||||||
const DataTypeVector& output_dtypes,
|
const DataTypeVector& output_dtypes,
|
||||||
@ -36,6 +37,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
|
|||||||
other_arguments_(std::move(other_arguments)),
|
other_arguments_(std::move(other_arguments)),
|
||||||
cycle_length_(cycle_length),
|
cycle_length_(cycle_length),
|
||||||
block_length_(block_length),
|
block_length_(block_length),
|
||||||
|
buffer_output_elements_(buffer_output_elements),
|
||||||
|
prefetch_input_elements_(prefetch_input_elements),
|
||||||
num_parallel_calls_(num_parallel_calls),
|
num_parallel_calls_(num_parallel_calls),
|
||||||
func_(std::move(func)),
|
func_(std::move(func)),
|
||||||
func_lib_(std::move(func_lib)),
|
func_lib_(std::move(func_lib)),
|
||||||
@ -56,6 +59,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
|
|||||||
CreateTensor<int64>(TensorShape({}), {cycle_length_}));
|
CreateTensor<int64>(TensorShape({}), {cycle_length_}));
|
||||||
input_tensors.emplace_back(
|
input_tensors.emplace_back(
|
||||||
CreateTensor<int64>(TensorShape({}), {block_length_}));
|
CreateTensor<int64>(TensorShape({}), {block_length_}));
|
||||||
|
input_tensors.emplace_back(
|
||||||
|
CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
|
||||||
|
input_tensors.emplace_back(
|
||||||
|
CreateTensor<int64>(TensorShape({}), {prefetch_input_elements_}));
|
||||||
input_tensors.emplace_back(
|
input_tensors.emplace_back(
|
||||||
CreateTensor<int64>(TensorShape({}), {num_parallel_calls_}));
|
CreateTensor<int64>(TensorShape({}), {num_parallel_calls_}));
|
||||||
return input_tensors;
|
return input_tensors;
|
||||||
@ -69,6 +76,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
|
|||||||
}
|
}
|
||||||
input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
|
input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
|
||||||
input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
|
input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
|
||||||
|
input_names->emplace_back(
|
||||||
|
ParallelInterleaveDatasetOp::kBufferOutputElements);
|
||||||
|
input_names->emplace_back(
|
||||||
|
ParallelInterleaveDatasetOp::kPrefetchInputElements);
|
||||||
input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls);
|
input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -93,6 +104,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
|
|||||||
std::vector<Tensor> other_arguments_;
|
std::vector<Tensor> other_arguments_;
|
||||||
int64 cycle_length_;
|
int64 cycle_length_;
|
||||||
int64 block_length_;
|
int64 block_length_;
|
||||||
|
int64 buffer_output_elements_;
|
||||||
|
int64 prefetch_input_elements_;
|
||||||
int64 num_parallel_calls_;
|
int64 num_parallel_calls_;
|
||||||
FunctionDefHelper::AttrValueWrapper func_;
|
FunctionDefHelper::AttrValueWrapper func_;
|
||||||
std::vector<FunctionDef> func_lib_;
|
std::vector<FunctionDef> func_lib_;
|
||||||
@ -111,8 +124,6 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
|
|||||||
{"output_shapes", output_shapes}});
|
{"output_shapes", output_shapes}});
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 1: cycle_length = 1, block_length = 1, num_parallel_calls = 1,
|
|
||||||
// sloppy = false
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
||||||
@ -123,6 +134,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/1,
|
/*cycle_length=*/1,
|
||||||
/*block_length=*/1,
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/1,
|
/*num_parallel_calls=*/1,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -136,8 +149,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 2: cycle_length = 2, block_length = 1, num_parallel_calls = 2,
|
|
||||||
// sloppy = false
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
||||||
@ -148,6 +159,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/2,
|
/*cycle_length=*/2,
|
||||||
/*block_length=*/1,
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/0,
|
||||||
|
/*prefetch_input_elements=*/0,
|
||||||
/*num_parallel_calls=*/2,
|
/*num_parallel_calls=*/2,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -161,8 +174,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 3: cycle_length = 3, block_length = 1, num_parallel_calls = 2,
|
|
||||||
// sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
||||||
@ -173,6 +184,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/3,
|
/*cycle_length=*/3,
|
||||||
/*block_length=*/1,
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/0,
|
||||||
|
/*prefetch_input_elements=*/1,
|
||||||
/*num_parallel_calls=*/2,
|
/*num_parallel_calls=*/2,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -186,9 +199,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 4: cycle_length = 5, block_length = 1, num_parallel_calls = 4,
|
|
||||||
// sloppy = true
|
|
||||||
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
||||||
@ -199,6 +209,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/5,
|
/*cycle_length=*/5,
|
||||||
/*block_length=*/1,
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/1,
|
||||||
|
/*prefetch_input_elements=*/0,
|
||||||
/*num_parallel_calls=*/4,
|
/*num_parallel_calls=*/4,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -212,8 +224,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 5: cycle_length = 2, block_length = 2, num_parallel_calls = 1,
|
|
||||||
// sloppy = false
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<tstring>(
|
/*components=*/{CreateTensor<tstring>(
|
||||||
@ -224,6 +234,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/2,
|
/*cycle_length=*/2,
|
||||||
/*block_length=*/2,
|
/*block_length=*/2,
|
||||||
|
/*buffer_output_elements=*/2,
|
||||||
|
/*prefetch_input_elements=*/2,
|
||||||
/*num_parallel_calls=*/1,
|
/*num_parallel_calls=*/1,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -237,8 +249,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 6: cycle_length = 2, block_length = 3, num_parallel_calls = 2,
|
|
||||||
// sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<tstring>(
|
/*components=*/{CreateTensor<tstring>(
|
||||||
@ -249,6 +259,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/2,
|
/*cycle_length=*/2,
|
||||||
/*block_length=*/3,
|
/*block_length=*/3,
|
||||||
|
/*buffer_output_elements=*/100,
|
||||||
|
/*prefetch_input_elements=*/100,
|
||||||
/*num_parallel_calls=*/2,
|
/*num_parallel_calls=*/2,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -262,8 +274,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 7: cycle_length = 3, block_length = 2, num_parallel_calls = 2,
|
|
||||||
// sloppy = false
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<tstring>(
|
/*components=*/{CreateTensor<tstring>(
|
||||||
@ -274,6 +284,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/3,
|
/*cycle_length=*/3,
|
||||||
/*block_length=*/2,
|
/*block_length=*/2,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/2,
|
/*num_parallel_calls=*/2,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -287,8 +299,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 8: cycle_length = 3, block_length = 3, num_parallel_calls = 3,
|
|
||||||
// sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<tstring>(
|
/*components=*/{CreateTensor<tstring>(
|
||||||
@ -299,6 +309,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/3,
|
/*cycle_length=*/3,
|
||||||
/*block_length=*/3,
|
/*block_length=*/3,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/3,
|
/*num_parallel_calls=*/3,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -312,8 +324,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 9: cycle_length = 4, block_length = 4, num_parallel_calls = 4,
|
|
||||||
// sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<tstring>(
|
/*components=*/{CreateTensor<tstring>(
|
||||||
@ -324,6 +334,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/4,
|
/*cycle_length=*/4,
|
||||||
/*block_length=*/4,
|
/*block_length=*/4,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/4,
|
/*num_parallel_calls=*/4,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -337,8 +349,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 10: cycle_length = 3, block_length = 3,
|
|
||||||
// num_parallel_calls = kAutotune, sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
|
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
/*components=*/{CreateTensor<tstring>(
|
/*components=*/{CreateTensor<tstring>(
|
||||||
@ -349,6 +359,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/4,
|
/*cycle_length=*/4,
|
||||||
/*block_length=*/4,
|
/*block_length=*/4,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/model::kAutotune,
|
/*num_parallel_calls=*/model::kAutotune,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -372,6 +384,8 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/11,
|
/*cycle_length=*/11,
|
||||||
/*block_length=*/1,
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/2,
|
/*num_parallel_calls=*/2,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -385,8 +399,6 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 11: cycle_length = 0, block_length = 1, num_parallel_calls = 2,
|
|
||||||
// sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams
|
ParallelInterleaveDatasetParams
|
||||||
ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
|
ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
@ -398,6 +410,8 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/0,
|
/*cycle_length=*/0,
|
||||||
/*block_length=*/1,
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/2,
|
/*num_parallel_calls=*/2,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -411,8 +425,6 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 12: cycle_length = 1, block_length = -1, num_parallel_calls = 2,
|
|
||||||
// sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams
|
ParallelInterleaveDatasetParams
|
||||||
ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
|
ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
@ -424,6 +436,8 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/1,
|
/*cycle_length=*/1,
|
||||||
/*block_length=*/-1,
|
/*block_length=*/-1,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/2,
|
/*num_parallel_calls=*/2,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -437,8 +451,6 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
|
|||||||
/*node_name=*/kNodeName);
|
/*node_name=*/kNodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
// test case 13: cycle_length = 1, block_length = 1, num_parallel_calls = -5,
|
|
||||||
// sloppy = true
|
|
||||||
ParallelInterleaveDatasetParams
|
ParallelInterleaveDatasetParams
|
||||||
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
|
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
|
||||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
@ -450,6 +462,60 @@ ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
|
|||||||
/*other_arguments=*/{},
|
/*other_arguments=*/{},
|
||||||
/*cycle_length=*/1,
|
/*cycle_length=*/1,
|
||||||
/*block_length=*/1,
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/model::kAutotune,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
|
/*num_parallel_calls=*/-5,
|
||||||
|
/*func=*/
|
||||||
|
MakeTensorSliceDatasetFunc(
|
||||||
|
DataTypeVector({DT_INT64}),
|
||||||
|
std::vector<PartialTensorShape>({PartialTensorShape({1})})),
|
||||||
|
/*func_lib=*/{test::function::MakeTensorSliceDataset()},
|
||||||
|
/*type_arguments=*/{},
|
||||||
|
/*output_dtypes=*/{DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({1})},
|
||||||
|
/*deterministic=*/DeterminismPolicy::kNondeterministic,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
ParallelInterleaveDatasetParams
|
||||||
|
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements() {
|
||||||
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
||||||
|
{0, 1, 2, 3, 4, 5, 6, 7, 8})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return ParallelInterleaveDatasetParams(
|
||||||
|
tensor_slice_dataset_params,
|
||||||
|
/*other_arguments=*/{},
|
||||||
|
/*cycle_length=*/1,
|
||||||
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/-2,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
|
/*num_parallel_calls=*/-5,
|
||||||
|
/*func=*/
|
||||||
|
MakeTensorSliceDatasetFunc(
|
||||||
|
DataTypeVector({DT_INT64}),
|
||||||
|
std::vector<PartialTensorShape>({PartialTensorShape({1})})),
|
||||||
|
/*func_lib=*/{test::function::MakeTensorSliceDataset()},
|
||||||
|
/*type_arguments=*/{},
|
||||||
|
/*output_dtypes=*/{DT_INT64},
|
||||||
|
/*output_shapes=*/{PartialTensorShape({1})},
|
||||||
|
/*deterministic=*/DeterminismPolicy::kNondeterministic,
|
||||||
|
/*node_name=*/kNodeName);
|
||||||
|
}
|
||||||
|
|
||||||
|
ParallelInterleaveDatasetParams
|
||||||
|
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements() {
|
||||||
|
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||||
|
/*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
|
||||||
|
{0, 1, 2, 3, 4, 5, 6, 7, 8})},
|
||||||
|
/*node_name=*/"tensor_slice");
|
||||||
|
return ParallelInterleaveDatasetParams(
|
||||||
|
tensor_slice_dataset_params,
|
||||||
|
/*other_arguments=*/{},
|
||||||
|
/*cycle_length=*/1,
|
||||||
|
/*block_length=*/1,
|
||||||
|
/*buffer_output_elements=*/-2,
|
||||||
|
/*prefetch_input_elements=*/model::kAutotune,
|
||||||
/*num_parallel_calls=*/-5,
|
/*num_parallel_calls=*/-5,
|
||||||
/*func=*/
|
/*func=*/
|
||||||
MakeTensorSliceDatasetFunc(
|
MakeTensorSliceDatasetFunc(
|
||||||
@ -698,7 +764,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
|
|||||||
std::vector<ParallelInterleaveDatasetParams> invalid_params = {
|
std::vector<ParallelInterleaveDatasetParams> invalid_params = {
|
||||||
ParallelInterleaveDatasetParamsWithInvalidCycleLength(),
|
ParallelInterleaveDatasetParamsWithInvalidCycleLength(),
|
||||||
ParallelInterleaveDatasetParamsWithInvalidBlockLength(),
|
ParallelInterleaveDatasetParamsWithInvalidBlockLength(),
|
||||||
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls()};
|
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls(),
|
||||||
|
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements(),
|
||||||
|
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements(),
|
||||||
|
};
|
||||||
for (auto& dataset_params : invalid_params) {
|
for (auto& dataset_params : invalid_params) {
|
||||||
EXPECT_EQ(Initialize(dataset_params).code(),
|
EXPECT_EQ(Initialize(dataset_params).code(),
|
||||||
tensorflow::error::INVALID_ARGUMENT);
|
tensorflow::error::INVALID_ARGUMENT);
|
||||||
|
@ -227,6 +227,24 @@ REGISTER_OP("ParallelInterleaveDatasetV3")
|
|||||||
.Attr("output_shapes: list(shape) >= 1")
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
.SetShapeFn(shape_inference::ScalarShape);
|
.SetShapeFn(shape_inference::ScalarShape);
|
||||||
|
|
||||||
|
// Like V3, but adds buffer_output_elements and prefetch_input_elements.
|
||||||
|
REGISTER_OP("ParallelInterleaveDatasetV4")
|
||||||
|
.Input("input_dataset: variant")
|
||||||
|
.Input("other_arguments: Targuments")
|
||||||
|
.Input("cycle_length: int64")
|
||||||
|
.Input("block_length: int64")
|
||||||
|
.Input("buffer_output_elements: int64")
|
||||||
|
.Input("prefetch_input_elements: int64")
|
||||||
|
.Input("num_parallel_calls: int64")
|
||||||
|
.Output("handle: variant")
|
||||||
|
.Attr("f: func")
|
||||||
|
// "true", "false", or "default".
|
||||||
|
.Attr("deterministic: string = 'default'")
|
||||||
|
.Attr("Targuments: list(type) >= 0")
|
||||||
|
.Attr("output_types: list(type) >= 1")
|
||||||
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
|
.SetShapeFn(shape_inference::ScalarShape);
|
||||||
|
|
||||||
REGISTER_OP("FilterDataset")
|
REGISTER_OP("FilterDataset")
|
||||||
.Input("input_dataset: variant")
|
.Input("input_dataset: variant")
|
||||||
.Input("other_arguments: Targuments")
|
.Input("other_arguments: Targuments")
|
||||||
|
@ -64,6 +64,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
parallel_interleave = "ParallelInterleaveV2"
|
parallel_interleave = "ParallelInterleaveV2"
|
||||||
if compat.forward_compatible(2020, 2, 20):
|
if compat.forward_compatible(2020, 2, 20):
|
||||||
parallel_interleave = "ParallelInterleaveV3"
|
parallel_interleave = "ParallelInterleaveV3"
|
||||||
|
if compat.forward_compatible(2020, 3, 6):
|
||||||
|
parallel_interleave = "ParallelInterleaveV4"
|
||||||
dataset = dataset.apply(
|
dataset = dataset.apply(
|
||||||
testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
|
testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
|
||||||
dataset = dataset.interleave(
|
dataset = dataset.interleave(
|
||||||
@ -79,6 +81,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
parallel_interleave = "ParallelInterleaveV2"
|
parallel_interleave = "ParallelInterleaveV2"
|
||||||
if compat.forward_compatible(2020, 2, 20):
|
if compat.forward_compatible(2020, 2, 20):
|
||||||
parallel_interleave = "ParallelInterleaveV3"
|
parallel_interleave = "ParallelInterleaveV3"
|
||||||
|
if compat.forward_compatible(2020, 3, 6):
|
||||||
|
parallel_interleave = "ParallelInterleaveV4"
|
||||||
dataset = dataset.apply(
|
dataset = dataset.apply(
|
||||||
testing.assert_next([
|
testing.assert_next([
|
||||||
"ParallelMap", "Prefetch", parallel_interleave, "Prefetch",
|
"ParallelMap", "Prefetch", parallel_interleave, "Prefetch",
|
||||||
|
@ -1714,9 +1714,13 @@ name=None))
|
|||||||
if num_parallel_calls is None:
|
if num_parallel_calls is None:
|
||||||
return InterleaveDataset(self, map_func, cycle_length, block_length)
|
return InterleaveDataset(self, map_func, cycle_length, block_length)
|
||||||
else:
|
else:
|
||||||
return ParallelInterleaveDataset(self, map_func, cycle_length,
|
return ParallelInterleaveDataset(
|
||||||
block_length, num_parallel_calls,
|
self,
|
||||||
deterministic)
|
map_func,
|
||||||
|
cycle_length,
|
||||||
|
block_length,
|
||||||
|
num_parallel_calls,
|
||||||
|
deterministic=deterministic)
|
||||||
|
|
||||||
def filter(self, predicate):
|
def filter(self, predicate):
|
||||||
"""Filters this dataset according to `predicate`.
|
"""Filters this dataset according to `predicate`.
|
||||||
@ -4042,6 +4046,8 @@ class ParallelInterleaveDataset(UnaryDataset):
|
|||||||
cycle_length,
|
cycle_length,
|
||||||
block_length,
|
block_length,
|
||||||
num_parallel_calls,
|
num_parallel_calls,
|
||||||
|
buffer_output_elements=AUTOTUNE,
|
||||||
|
prefetch_input_elements=AUTOTUNE,
|
||||||
deterministic=None):
|
deterministic=None):
|
||||||
"""See `Dataset.interleave()` for details."""
|
"""See `Dataset.interleave()` for details."""
|
||||||
self._input_dataset = input_dataset
|
self._input_dataset = input_dataset
|
||||||
@ -4056,6 +4062,15 @@ class ParallelInterleaveDataset(UnaryDataset):
|
|||||||
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
cycle_length, dtype=dtypes.int64, name="cycle_length")
|
||||||
self._block_length = ops.convert_to_tensor(
|
self._block_length = ops.convert_to_tensor(
|
||||||
block_length, dtype=dtypes.int64, name="block_length")
|
block_length, dtype=dtypes.int64, name="block_length")
|
||||||
|
self._buffer_output_elements = ops.convert_to_tensor(
|
||||||
|
buffer_output_elements,
|
||||||
|
dtype=dtypes.int64,
|
||||||
|
name="buffer_output_elements")
|
||||||
|
self._prefetch_input_elements = ops.convert_to_tensor(
|
||||||
|
prefetch_input_elements,
|
||||||
|
dtype=dtypes.int64,
|
||||||
|
name="prefetch_input_elements")
|
||||||
|
|
||||||
self._num_parallel_calls = ops.convert_to_tensor(
|
self._num_parallel_calls = ops.convert_to_tensor(
|
||||||
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
|
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
|
||||||
if deterministic is None:
|
if deterministic is None:
|
||||||
@ -4065,7 +4080,21 @@ class ParallelInterleaveDataset(UnaryDataset):
|
|||||||
else:
|
else:
|
||||||
deterministic_string = "false"
|
deterministic_string = "false"
|
||||||
|
|
||||||
if deterministic is not None or compat.forward_compatible(2020, 2, 20):
|
if (buffer_output_elements != AUTOTUNE or
|
||||||
|
prefetch_input_elements != AUTOTUNE or
|
||||||
|
compat.forward_compatible(2020, 3, 6)):
|
||||||
|
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
|
||||||
|
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
|
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
||||||
|
self._cycle_length,
|
||||||
|
self._block_length,
|
||||||
|
self._buffer_output_elements,
|
||||||
|
self._prefetch_input_elements,
|
||||||
|
self._num_parallel_calls,
|
||||||
|
f=self._map_func.function,
|
||||||
|
deterministic=deterministic_string,
|
||||||
|
**self._flat_structure)
|
||||||
|
elif deterministic is not None or compat.forward_compatible(2020, 2, 20):
|
||||||
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
|
variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
|
||||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
self._map_func.function.captured_inputs, # pylint: disable=protected-access
|
||||||
|
@ -2636,6 +2636,10 @@ tf_module {
|
|||||||
name: "ParallelInterleaveDatasetV3"
|
name: "ParallelInterleaveDatasetV3"
|
||||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
|
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ParallelInterleaveDatasetV4"
|
||||||
|
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ParallelMapDataset"
|
name: "ParallelMapDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
|
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
|
||||||
|
@ -2636,6 +2636,10 @@ tf_module {
|
|||||||
name: "ParallelInterleaveDatasetV3"
|
name: "ParallelInterleaveDatasetV3"
|
||||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
|
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ParallelInterleaveDatasetV4"
|
||||||
|
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ParallelMapDataset"
|
name: "ParallelMapDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
|
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user