Support None option for experimental_interleave sloppiness.

If sloppy=None, the transform will use the experimental_deterministic option to
determine whether to use sloppy behavior.

PiperOrigin-RevId: 295252219
Change-Id: I93b4ac182ea4edd8666364827d226593a3d3c7bd
This commit is contained in:
Andrew Audibert 2020-02-14 16:28:36 -08:00 committed by TensorFlower Gardener
parent 33f00e722e
commit 745ed4714d
13 changed files with 284 additions and 101 deletions

View File

@ -0,0 +1,22 @@
op {
graph_op_name: "LegacyParallelInterleaveDatasetV2"
visibility: HIDDEN
attr {
name: "f"
description: <<END
A function mapping elements of `input_dataset`, concatenated with
`other_arguments`, to a Dataset variant that contains elements matching
`output_types` and `output_shapes`.
END
}
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
description: <<END
The resulting dataset is similar to the `InterleaveDataset`, with the exception
that if retrieving the next value from a dataset would cause the requester to
block, it will skip that input dataset. This dataset is especially useful
when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
allows the training step to proceed so long as some data is available.
!! WARNING !! This dataset is not deterministic!
END
}

View File

@ -90,10 +90,11 @@ constexpr std::array<const char*, 29> kPassThroughOps = {
}; };
// TODO(frankchn): Process functions within kFuncDatasetOps as well. // TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 7> kFuncDatasetOps = { constexpr std::array<const char*, 8> kFuncDatasetOps = {
"ExperimentalParallelInterleaveDataset", "ExperimentalParallelInterleaveDataset",
"FlatMapDataset", "FlatMapDataset",
"InterleaveDataset", "InterleaveDataset",
"LegacyParallelInterleaveDatasetV2",
"ParallelInterleaveDataset", "ParallelInterleaveDataset",
"ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV2",
"ParallelInterleaveDatasetV3", "ParallelInterleaveDatasetV3",

View File

@ -33,10 +33,10 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
}; };
constexpr std::array<const char*, 4> kDeterministicAttrOps = { constexpr std::array<const char*, 4> kDeterministicAttrOps = {
"LegacyParallelInterleaveDatasetV2",
"ParallelInterleaveDatasetV3", "ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV4", "ParallelInterleaveDatasetV4",
"ParallelMapDatasetV2", "ParallelMapDatasetV2",
"ParseExampleDatasetV2",
}; };
} // anonymous namespace } // anonymous namespace

View File

@ -46,6 +46,8 @@ namespace experimental {
ParallelInterleaveDatasetOp::kCycleLength; ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const /* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBlockLength; ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy; /* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
/* static */ constexpr const char* const /* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBufferOutputElements; ParallelInterleaveDatasetOp::kBufferOutputElements;
@ -90,15 +92,16 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
public: public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, bool sloppy, int64 buffer_output_elements, int64 block_length, DeterminismPolicy deterministic,
int64 prefetch_input_elements, const DataTypeVector& output_types, int64 buffer_output_elements, int64 prefetch_input_elements,
const std::vector<PartialTensorShape>& output_shapes) const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes, int op_version)
: DatasetBase(DatasetContext(ctx)), : DatasetBase(DatasetContext(ctx)),
input_(input), input_(input),
captured_func_(std::move(captured_func)), captured_func_(std::move(captured_func)),
cycle_length_(cycle_length), cycle_length_(cycle_length),
block_length_(block_length), block_length_(block_length),
sloppy_(sloppy), deterministic_(deterministic),
buffer_output_elements_(buffer_output_elements), buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements), prefetch_input_elements_(prefetch_input_elements),
output_types_(output_types), output_types_(output_types),
@ -106,7 +109,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
traceme_metadata_( traceme_metadata_(
{{"block_length", strings::Printf("%lld", block_length)}, {{"block_length", strings::Printf("%lld", block_length)},
{"cycle_length", strings::Printf("%lld", cycle_length)}, {"cycle_length", strings::Printf("%lld", cycle_length)},
{"deterministic", sloppy ? "false" : "true"}}) { {"deterministic",
deterministic.IsDeterministic() || deterministic.IsDefault()
? "true"
: "false"}}),
op_version_(op_version) {
input_->Ref(); input_->Ref();
} }
@ -114,8 +121,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
std::unique_ptr<IteratorBase> MakeIteratorInternal( std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override { const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{ name_utils::IteratorPrefixParams params;
this, name_utils::IteratorPrefix(kDatasetType, prefix)}); params.op_version = op_version_;
bool deterministic =
deterministic_.IsDeterministic() || deterministic_.IsDefault();
return absl::make_unique<Iterator>(
Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
deterministic);
} }
const DataTypeVector& output_dtypes() const override { return output_types_; } const DataTypeVector& output_dtypes() const override { return output_types_; }
@ -125,7 +138,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
} }
string DebugString() const override { string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType); name_utils::DatasetDebugStringParams params;
params.op_version = op_version_;
return name_utils::DatasetDebugString(kDatasetType, params);
} }
Status CheckExternalState() const override { Status CheckExternalState() const override {
@ -137,39 +152,62 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
Node** output) const override { Node** output) const override {
std::vector<std::pair<size_t, Node*>> inputs;
std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
int input_index = 0;
Node* input_node; Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node; inputs.emplace_back(input_index++, input_node);
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
Node* sloppy_node;
TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
Node* buffer_output_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
Node* prefetch_input_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
std::vector<Node*> other_arguments; std::vector<Node*> other_arguments;
DataTypeVector other_arguments_types; DataTypeVector other_arguments_types;
TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments, TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
&other_arguments_types)); &other_arguments_types));
list_inputs.emplace_back(input_index++, other_arguments);
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
inputs.emplace_back(input_index++, cycle_length_node);
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
inputs.emplace_back(input_index++, block_length_node);
if (op_version_ == 1) {
Node* sloppy_node;
TF_RETURN_IF_ERROR(
b->AddScalar(deterministic_.IsNondeterministic(), &sloppy_node));
inputs.emplace_back(input_index++, sloppy_node);
}
Node* buffer_output_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
inputs.emplace_back(input_index++, buffer_output_elements_node);
Node* prefetch_input_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
inputs.emplace_back(input_index++, prefetch_input_elements_node);
std::vector<std::pair<StringPiece, AttrValue>> attrs;
AttrValue f; AttrValue f;
b->BuildAttrValue(captured_func_->func(), &f); b->BuildAttrValue(captured_func_->func(), &f);
attrs.emplace_back(kFunc, f);
if (op_version_ == 2) {
AttrValue deterministic_attr;
b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
attrs.emplace_back(kDeterministic, deterministic_attr);
}
AttrValue other_arguments_types_attr; AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
attrs.emplace_back(kTarguments, other_arguments_types_attr);
TF_RETURN_IF_ERROR(b->AddDataset( TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
this,
{{0, input_node},
{2, cycle_length_node},
{3, block_length_node},
{4, sloppy_node},
{5, buffer_output_elements_node},
{6, prefetch_input_elements_node}},
{{1, other_arguments}},
{{kFunc, f}, {kTarguments, other_arguments_types_attr}}, output));
return Status::OK(); return Status::OK();
} }
@ -226,8 +264,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// an element in `interleave_indices_` or `staging_indices_`. // an element in `interleave_indices_` or `staging_indices_`.
class Iterator : public DatasetIterator<Dataset> { class Iterator : public DatasetIterator<Dataset> {
public: public:
explicit Iterator(const Params& params) explicit Iterator(const Params& params, bool deterministic)
: DatasetIterator<Dataset>(params), : DatasetIterator<Dataset>(params),
deterministic_(deterministic),
workers_(dataset()->num_threads()), workers_(dataset()->num_threads()),
worker_thread_states_(dataset()->num_threads()) {} worker_thread_states_(dataset()->num_threads()) {}
@ -244,7 +283,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// It is implemented so that it matches the deterministic interleave // It is implemented so that it matches the deterministic interleave
// unless getting the next element would block and we are allowed to be // unless getting the next element would block and we are allowed to be
// sloppy. // nondeterministic.
Status GetNextInternal(IteratorContext* ctx, Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override { bool* end_of_sequence) override {
@ -252,8 +291,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
while (!cancelled_) { while (!cancelled_) {
// Wait for an item to become available, blocking if necessary. If we // Wait for an item to become available, blocking if necessary. If we
// are allowed to be sloppy, we can skip over input datasets that do // are allowed to be nondeterministic, we can skip over input datasets
// not have an item readily available. // that do not have an item readily available.
bool can_produce_elements = false; bool can_produce_elements = false;
bool must_wait_for_input = true; bool must_wait_for_input = true;
for (int64 i = 0; i < interleave_indices_.size(); ++i) { for (int64 i = 0; i < interleave_indices_.size(); ++i) {
@ -267,9 +306,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
if (!current_worker->outputs.empty()) { if (!current_worker->outputs.empty()) {
// We have an element! // We have an element!
next_index_ = index; next_index_ = index;
const bool element_acquired_sloppily = dataset()->sloppy_ && i > 1; const bool element_acquired_sloppily = !deterministic_ && i > 1;
if (!element_acquired_sloppily) { if (!element_acquired_sloppily) {
// If the element was acquired in the regular (non-sloppy) // If the element was acquired in the regular (deterministic)
// order, then advance the current block and cycle pointers to // order, then advance the current block and cycle pointers to
// the next element in the regular order. // the next element in the regular order.
block_count_++; block_count_++;
@ -286,7 +325,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
current_worker->outputs.pop_front(); current_worker->outputs.pop_front();
current_worker->cond_var.notify_one(); current_worker->cond_var.notify_one();
return s; return s;
} else if (current_worker->is_producing && !dataset()->sloppy_) { } else if (current_worker->is_producing && deterministic_) {
// current_worker.outputs.empty(), and we must wait for this // current_worker.outputs.empty(), and we must wait for this
// iterator. // iterator.
if (next_index_ != index) { if (next_index_ != index) {
@ -336,10 +375,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
if (must_wait_for_input) { if (must_wait_for_input) {
// Wait for elements to become available. // Wait for elements to become available.
RecordStop(ctx); RecordStop(ctx);
if (dataset()->sloppy_) { if (deterministic_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l); workers_[interleave_indices_[next_index_]].cond_var.wait(l);
} else {
any_element_available_cond_var_.wait(l);
} }
RecordStart(ctx); RecordStart(ctx);
} }
@ -542,7 +581,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// for the main thread to add arguments to `input`, or (2) waiting for // for the main thread to add arguments to `input`, or (2) waiting for
// the main thread to consume an element of `outputs`. The main thread // the main thread to consume an element of `outputs`. The main thread
// waits on cond_var if it is waiting for the worker thread to produce // waits on cond_var if it is waiting for the worker thread to produce
// an element into `outputs` (this implies sloppy_==false). // an element into `outputs` (this implies deterministic==true).
condition_variable cond_var; condition_variable cond_var;
inline bool MayHaveElements() const { inline bool MayHaveElements() const {
@ -754,10 +793,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// CHECKPOINT_MARKER_C // CHECKPOINT_MARKER_C
// Non-OK iterator creation status has been notified to the // Non-OK iterator creation status has been notified to the
// client. // client.
if (dataset()->sloppy_) { if (deterministic_) {
sloppy_cond_var_.notify_one();
} else {
workers_[thread_index].cond_var.notify_one(); workers_[thread_index].cond_var.notify_one();
} else {
any_element_available_cond_var_.notify_one();
} }
} else { } else {
bool end_of_sequence = false; bool end_of_sequence = false;
@ -818,10 +857,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
} }
worker_thread_states_[thread_index].output_elem.status = worker_thread_states_[thread_index].output_elem.status =
Status::OK(); Status::OK();
if (dataset()->sloppy_) { if (deterministic_) {
sloppy_cond_var_.notify_one();
} else {
workers_[thread_index].cond_var.notify_one(); workers_[thread_index].cond_var.notify_one();
} else {
any_element_available_cond_var_.notify_one();
} }
// CHECKPOINT_MARKER_E // CHECKPOINT_MARKER_E
// Output element or iterator status has been sent to the // Output element or iterator status has been sent to the
@ -1040,9 +1079,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// Mutex & condition variable to guard mutable iterator internals and // Mutex & condition variable to guard mutable iterator internals and
// coordinate among worker threads and client thread[s]. // coordinate among worker threads and client thread[s].
mutex mu_ ACQUIRED_BEFORE(ckpt_mu_); mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
// The main thread waits on this condition variable if running in sloppy // The main thread waits on this condition variable if running in
// mode and no values are available. // nondeterministic mode and no values are available.
condition_variable sloppy_cond_var_; condition_variable any_element_available_cond_var_;
// Whether outputs must be produced in deterministic order.
const bool deterministic_;
// Mutex used to wait for a consistent state while checkpointing. // Mutex used to wait for a consistent state while checkpointing.
// Only Save and Restore require an exclusive lock on this mutex. In // Only Save and Restore require an exclusive lock on this mutex. In
// other scenarios we just acquire a shared lock so the pipeline's // other scenarios we just acquire a shared lock so the pipeline's
@ -1087,21 +1128,29 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
const std::unique_ptr<CapturedFunction> captured_func_; const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_; const int64 cycle_length_;
const int64 block_length_; const int64 block_length_;
const bool sloppy_; const DeterminismPolicy deterministic_;
const int64 buffer_output_elements_; const int64 buffer_output_elements_;
const int64 prefetch_input_elements_; const int64 prefetch_input_elements_;
const DataTypeVector output_types_; const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_; const std::vector<PartialTensorShape> output_shapes_;
const TraceMeMetadata traceme_metadata_; const TraceMeMetadata traceme_metadata_;
const int op_version_;
}; };
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp( ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
OpKernelConstruction* ctx) OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) { : UnaryDatasetOpKernel(ctx),
op_version_(ctx->HasAttr(kDeterministic) ? 2 : 1) {
FunctionMetadata::Params params; FunctionMetadata::Params params;
params.is_multi_device_function = true; params.is_multi_device_function = true;
OP_REQUIRES_OK(ctx, OP_REQUIRES_OK(ctx,
FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_)); FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
if (op_version_ == 2) {
std::string deterministic;
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
OP_REQUIRES_OK(
ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
}
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
} }
@ -1119,8 +1168,17 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
OP_REQUIRES(ctx, block_length > 0, OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0")); errors::InvalidArgument("`block_length` must be > 0"));
bool sloppy = false; if (op_version_ == 1) {
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy)); bool sloppy = false;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
if (sloppy) {
deterministic_ =
DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
} else {
deterministic_ =
DeterminismPolicy(DeterminismPolicy::Type::kDeterministic);
}
}
int64 buffer_output_elements = 0; int64 buffer_output_elements = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements, OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
@ -1141,8 +1199,9 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
&captured_func)); &captured_func));
*output = new Dataset(ctx, input, std::move(captured_func), cycle_length, *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
block_length, sloppy, buffer_output_elements, block_length, deterministic_, buffer_output_elements,
prefetch_input_elements, output_types_, output_shapes_); prefetch_input_elements, output_types_, output_shapes_,
op_version_);
} }
namespace { namespace {
@ -1151,9 +1210,13 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER(
Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU), Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp); ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("LegacyParallelInterleaveDatasetV2").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset"); REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset"); REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("LegacyParallelInterleaveDatasetV2");
} // namespace } // namespace
} // namespace experimental } // namespace experimental

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
@ -27,11 +28,12 @@ namespace experimental {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public: public:
static constexpr const char* const kDatasetType = "ParallelInterleave"; static constexpr const char* const kDatasetType = "LegacyParallelInterleave";
static constexpr const char* const kInputDataset = "input_dataset"; static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kOtherArguments = "other_arguments"; static constexpr const char* const kOtherArguments = "other_arguments";
static constexpr const char* const kCycleLength = "cycle_length"; static constexpr const char* const kCycleLength = "cycle_length";
static constexpr const char* const kBlockLength = "block_length"; static constexpr const char* const kBlockLength = "block_length";
static constexpr const char* const kDeterministic = "deterministic";
static constexpr const char* const kSloppy = "sloppy"; static constexpr const char* const kSloppy = "sloppy";
static constexpr const char* const kBufferOutputElements = static constexpr const char* const kBufferOutputElements =
"buffer_output_elements"; "buffer_output_elements";
@ -50,10 +52,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
private: private:
class Dataset; class Dataset;
const int op_version_;
std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr; std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
DataTypeVector output_types_; DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_; std::vector<PartialTensorShape> output_shapes_;
DeterminismPolicy deterministic_;
}; };
} // namespace experimental } // namespace experimental

View File

@ -20,33 +20,37 @@ namespace experimental {
namespace { namespace {
constexpr char kNodeName[] = "parallel_interleave_dataset"; constexpr char kNodeName[] = "parallel_interleave_dataset";
constexpr int kOpVersion = 2;
class ParallelInterleaveDatasetParams : public DatasetParams { class ParallelInterleaveDatasetParams : public DatasetParams {
public: public:
template <typename T> template <typename T>
ParallelInterleaveDatasetParams( ParallelInterleaveDatasetParams(
T input_dataset_params, std::vector<Tensor> other_arguments, T input_dataset_params, std::vector<Tensor> other_arguments,
int64 cycle_length, int64 block_length, bool sloppy, int64 cycle_length, int64 block_length, const std::string& deterministic,
int64 buffer_output_elements, int64 prefetch_input_elements, int64 buffer_output_elements, int64 prefetch_input_elements,
FunctionDefHelper::AttrValueWrapper func, FunctionDefHelper::AttrValueWrapper func,
std::vector<FunctionDef> func_lib, DataTypeVector type_arguments, std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
DataTypeVector output_dtypes, const DataTypeVector& output_dtypes,
std::vector<PartialTensorShape> output_shapes, string node_name) const std::vector<PartialTensorShape>& output_shapes, string node_name)
: DatasetParams(std::move(output_dtypes), std::move(output_shapes), : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
std::move(node_name)), std::move(node_name)),
other_arguments_(std::move(other_arguments)), other_arguments_(std::move(other_arguments)),
cycle_length_(cycle_length), cycle_length_(cycle_length),
block_length_(block_length), block_length_(block_length),
sloppy_(sloppy), deterministic_(deterministic),
buffer_output_elements_(buffer_output_elements), buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements), prefetch_input_elements_(prefetch_input_elements),
func_(std::move(func)), func_(std::move(func)),
func_lib_(std::move(func_lib)), func_lib_(std::move(func_lib)),
type_arguments_(std::move(type_arguments)) { type_arguments_(std::move(type_arguments)) {
input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params)); input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
iterator_prefix_ = op_version_ = kOpVersion;
name_utils::IteratorPrefix(input_dataset_params.dataset_type(), name_utils::IteratorPrefixParams params;
input_dataset_params.iterator_prefix()); params.op_version = op_version_;
iterator_prefix_ = name_utils::IteratorPrefix(
input_dataset_params.dataset_type(),
input_dataset_params.iterator_prefix(), params);
} }
std::vector<Tensor> GetInputTensors() const override { std::vector<Tensor> GetInputTensors() const override {
@ -55,7 +59,6 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
CreateTensor<int64>(TensorShape({}), {cycle_length_})); CreateTensor<int64>(TensorShape({}), {cycle_length_}));
input_tensors.emplace_back( input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {block_length_})); CreateTensor<int64>(TensorShape({}), {block_length_}));
input_tensors.emplace_back(CreateTensor<bool>(TensorShape({}), {sloppy_}));
input_tensors.emplace_back( input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {buffer_output_elements_})); CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
input_tensors.emplace_back( input_tensors.emplace_back(
@ -71,7 +74,6 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
} }
input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength); input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength); input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
input_names->emplace_back(ParallelInterleaveDatasetOp::kSloppy);
input_names->emplace_back( input_names->emplace_back(
ParallelInterleaveDatasetOp::kBufferOutputElements); ParallelInterleaveDatasetOp::kBufferOutputElements);
input_names->emplace_back( input_names->emplace_back(
@ -82,6 +84,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
Status GetAttributes(AttributeVector* attr_vector) const override { Status GetAttributes(AttributeVector* attr_vector) const override {
*attr_vector = { *attr_vector = {
{ParallelInterleaveDatasetOp::kFunc, func_}, {ParallelInterleaveDatasetOp::kFunc, func_},
{ParallelInterleaveDatasetOp::kDeterministic, deterministic_},
{ParallelInterleaveDatasetOp::kTarguments, type_arguments_}, {ParallelInterleaveDatasetOp::kTarguments, type_arguments_},
{ParallelInterleaveDatasetOp::kOutputShapes, output_shapes_}, {ParallelInterleaveDatasetOp::kOutputShapes, output_shapes_},
{ParallelInterleaveDatasetOp::kOutputTypes, output_dtypes_}}; {ParallelInterleaveDatasetOp::kOutputTypes, output_dtypes_}};
@ -98,7 +101,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
std::vector<Tensor> other_arguments_; std::vector<Tensor> other_arguments_;
int64 cycle_length_; int64 cycle_length_;
int64 block_length_; int64 block_length_;
bool sloppy_; std::string deterministic_;
int64 buffer_output_elements_; int64 buffer_output_elements_;
int64 prefetch_input_elements_; int64 prefetch_input_elements_;
FunctionDefHelper::AttrValueWrapper func_; FunctionDefHelper::AttrValueWrapper func_;
@ -117,7 +120,7 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
{TensorSliceDatasetOp::kOutputShapes, output_shapes}}); {TensorSliceDatasetOp::kOutputShapes, output_shapes}});
} }
// Test case 1: cycle_length = 1, block_length = 1, sloppy = false, // Test case 1: cycle_length = 1, block_length = 1, deterministic = true,
// buffer_output_elements = 1, prefetch_input_elements = 1. // buffer_output_elements = 1, prefetch_input_elements = 1.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -129,7 +132,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/1, /*cycle_length=*/1,
/*block_length=*/1, /*block_length=*/1,
/*sloppy=*/false, /*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1, /*buffer_output_elements=*/1,
/*prefetch_input_elements=*/1, /*prefetch_input_elements=*/1,
/*func=*/ /*func=*/
@ -143,7 +146,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// Test case 2: cycle_length = 2, block_length = 1, sloppy = false, // Test case 2: cycle_length = 2, block_length = 1, deterministic = true,
// buffer_output_elements = 1, prefetch_input_elements = 0. // buffer_output_elements = 1, prefetch_input_elements = 0.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -155,7 +158,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/2, /*cycle_length=*/2,
/*block_length=*/1, /*block_length=*/1,
/*sloppy=*/false, /*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1, /*buffer_output_elements=*/1,
/*prefetch_input_elements=*/0, /*prefetch_input_elements=*/0,
/*func=*/ /*func=*/
@ -169,7 +172,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// Test case 3: cycle_length = 3, block_length = 1, sloppy = true, // Test case 3: cycle_length = 3, block_length = 1, deterministic = false,
// buffer_output_elements = 3, prefetch_input_elements = 2. // buffer_output_elements = 3, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -181,7 +184,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/3, /*cycle_length=*/3,
/*block_length=*/1, /*block_length=*/1,
/*sloppy=*/true, /*deterministic=*/DeterminismPolicy::kNondeterministic,
/*buffer_output_elements=*/3, /*buffer_output_elements=*/3,
/*prefetch_input_elements=*/2, /*prefetch_input_elements=*/2,
/*func=*/ /*func=*/
@ -195,7 +198,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// Test case 4: cycle_length = 5, block_length = 1, sloppy = true // Test case 4: cycle_length = 5, block_length = 1, deterministic = false
// buffer_output_elements = 1, prefetch_input_elements = 2. // buffer_output_elements = 1, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -207,7 +210,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/5, /*cycle_length=*/5,
/*block_length=*/1, /*block_length=*/1,
/*sloppy=*/true, /*deterministic=*/DeterminismPolicy::kNondeterministic,
/*buffer_output_elements=*/1, /*buffer_output_elements=*/1,
/*prefetch_input_elements=*/2, /*prefetch_input_elements=*/2,
/*func=*/ /*func=*/
@ -221,7 +224,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*node_name=*/kNodeName); /*node_name=*/kNodeName);
} }
// Test case 5: cycle_length = 2, block_length = 2, sloppy = false // Test case 5: cycle_length = 2, block_length = 2, deterministic = true
// buffer_output_elements = 2, prefetch_input_elements = 2. // buffer_output_elements = 2, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() { ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams( auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -233,7 +236,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/2, /*cycle_length=*/2,
/*block_length=*/2, /*block_length=*/2,
/*sloppy=*/false, /*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/2, /*buffer_output_elements=*/2,
/*prefetch_input_elements=*/2, /*prefetch_input_elements=*/2,
/*func=*/ /*func=*/
@ -256,7 +259,7 @@ ParallelInterleaveDatasetParams EmptyInputParams() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/2, /*cycle_length=*/2,
/*block_length=*/2, /*block_length=*/2,
/*sloppy=*/true, /*deterministic=*/DeterminismPolicy::kNondeterministic,
/*buffer_output_elements=*/2, /*buffer_output_elements=*/2,
/*prefetch_input_elements=*/2, /*prefetch_input_elements=*/2,
/*func=*/ /*func=*/
@ -280,7 +283,7 @@ ParallelInterleaveDatasetParams InvalidCycleLengthParams() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/0, /*cycle_length=*/0,
/*block_length=*/1, /*block_length=*/1,
/*sloppy=*/false, /*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1, /*buffer_output_elements=*/1,
/*prefetch_input_elements=*/1, /*prefetch_input_elements=*/1,
/*func=*/ /*func=*/
@ -304,7 +307,7 @@ ParallelInterleaveDatasetParams InvalidBlockLengthParams() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/1, /*cycle_length=*/1,
/*block_length=*/-1, /*block_length=*/-1,
/*sloppy=*/false, /*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1, /*buffer_output_elements=*/1,
/*prefetch_input_elements=*/1, /*prefetch_input_elements=*/1,
/*func=*/ /*func=*/
@ -328,7 +331,7 @@ ParallelInterleaveDatasetParams InvalidBufferOutputElementsParams() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/1, /*cycle_length=*/1,
/*block_length=*/1, /*block_length=*/1,
/*sloppy=*/false, /*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/0, /*buffer_output_elements=*/0,
/*prefetch_input_elements=*/1, /*prefetch_input_elements=*/1,
/*func=*/ /*func=*/
@ -352,7 +355,7 @@ ParallelInterleaveDatasetParams InvalidPrefetchInputElementsParams() {
/*other_arguments=*/{}, /*other_arguments=*/{},
/*cycle_length=*/1, /*cycle_length=*/1,
/*block_length=*/1, /*block_length=*/1,
/*sloppy=*/false, /*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1, /*buffer_output_elements=*/1,
/*prefetch_input_elements=*/-1, /*prefetch_input_elements=*/-1,
/*func=*/ /*func=*/
@ -412,8 +415,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, DatasetNodeName) {
TEST_F(ParallelInterleaveDatasetOpTest, DatasetTypeString) { TEST_F(ParallelInterleaveDatasetOpTest, DatasetTypeString) {
auto dataset_params = ParallelInterleaveDatasetParams1(); auto dataset_params = ParallelInterleaveDatasetParams1();
TF_ASSERT_OK(Initialize(dataset_params)); TF_ASSERT_OK(Initialize(dataset_params));
name_utils::OpNameParams params;
params.op_version = dataset_params.op_version();
TF_ASSERT_OK(CheckDatasetTypeString( TF_ASSERT_OK(CheckDatasetTypeString(
name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType))); name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType, params)));
} }
TEST_F(ParallelInterleaveDatasetOpTest, DatasetOutputDtypes) { TEST_F(ParallelInterleaveDatasetOpTest, DatasetOutputDtypes) {
@ -461,9 +466,11 @@ TEST_F(ParallelInterleaveDatasetOpTest, IteratorOutputShapes) {
TEST_F(ParallelInterleaveDatasetOpTest, IteratorPrefix) { TEST_F(ParallelInterleaveDatasetOpTest, IteratorPrefix) {
auto dataset_params = ParallelInterleaveDatasetParams1(); auto dataset_params = ParallelInterleaveDatasetParams1();
TF_ASSERT_OK(Initialize(dataset_params)); TF_ASSERT_OK(Initialize(dataset_params));
name_utils::IteratorPrefixParams params;
params.op_version = dataset_params.op_version();
TF_ASSERT_OK(CheckIteratorPrefix( TF_ASSERT_OK(CheckIteratorPrefix(
name_utils::IteratorPrefix(ParallelInterleaveDatasetOp::kDatasetType, name_utils::IteratorPrefix(ParallelInterleaveDatasetOp::kDatasetType,
dataset_params.iterator_prefix()))); dataset_params.iterator_prefix(), params)));
} }
std::vector<IteratorSaveAndRestoreTestCase<ParallelInterleaveDatasetParams>> std::vector<IteratorSaveAndRestoreTestCase<ParallelInterleaveDatasetParams>>

View File

@ -561,6 +561,26 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1") .Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape); .SetShapeFn(shape_inference::ScalarShape);
// This is the V2 of ParallelInterleaveDataset, renamed to differentiate it
// from the non-experimental ParallelInterleaveDataset op.
REGISTER_OP("LegacyParallelInterleaveDatasetV2")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("cycle_length: int64")
.Input("block_length: int64")
.Input("buffer_output_elements: int64")
.Input("prefetch_input_elements: int64")
.Output("handle: variant")
.Attr("f: func")
// "true", "false", or "default".
.Attr("deterministic: string = 'default'")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
// This op is no longer used. We keep it so that we can read graphs written by
// old versions of TensorFlow.
REGISTER_OP("ExperimentalParallelInterleaveDataset") REGISTER_OP("ExperimentalParallelInterleaveDataset")
.Input("input_dataset: variant") .Input("input_dataset: variant")
.Input("other_arguments: Targuments") .Input("other_arguments: Targuments")

View File

@ -466,6 +466,7 @@ tf_py_test(
"//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor", "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/experimental/ops:interleave_ops", "//tensorflow/python/data/experimental/ops:interleave_ops",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six", "@six_archive//:six",

View File

@ -26,6 +26,7 @@ import numpy as np
from six.moves import zip_longest from six.moves import zip_longest
from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations from tensorflow.python.framework import combinations
@ -729,6 +730,40 @@ class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
results.append(elements) results.append(elements)
self.assertAllEqual(results[0], results[1]) self.assertAllEqual(results[0], results[1])
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
sloppy=[None, True, False], global_determinism=[True, False])))
def testDeterminismConfiguration(self, sloppy, global_determinism):
if sloppy is None:
expect_determinism = global_determinism
else:
expect_determinism = not sloppy
elements = list(range(1000))
def dataset_fn(delay_ms):
def interleave_fn(x):
ds = dataset_ops.Dataset.from_tensors(x)
if math_ops.equal(x, 0):
ds = ds.apply(testing.sleep(delay_ms * 1000))
else:
ds = ds.apply(testing.sleep(0))
return ds
dataset = dataset_ops.Dataset.from_tensor_slices(elements)
dataset = dataset.apply(
interleave_ops.parallel_interleave(
interleave_fn, cycle_length=10, sloppy=sloppy))
opts = dataset_ops.Options()
opts.experimental_deterministic = global_determinism
dataset = dataset.with_options(opts)
return dataset
self.checkDeterminism(dataset_fn, expect_determinism, elements)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -76,9 +76,11 @@ def parallel_interleave(map_func,
cycle_length: The number of input `Dataset`s to interleave from in parallel. cycle_length: The number of input `Dataset`s to interleave from in parallel.
block_length: The number of consecutive elements to pull from an input block_length: The number of consecutive elements to pull from an input
`Dataset` before advancing to the next input `Dataset`. `Dataset` before advancing to the next input `Dataset`.
sloppy: If false, elements are produced in deterministic order. Otherwise, sloppy: A boolean controlling whether determinism should be traded for
the implementation is allowed, for the sake of expediency, to produce performance by allowing elements to be produced out of order. If
elements in a non-deterministic order. `sloppy` is `None`, the `tf.data.Options.experimental_deterministic`
dataset option (`True` by default) is used to decide whether to enforce a
deterministic order.
buffer_output_elements: The number of elements each iterator being buffer_output_elements: The number of elements each iterator being
interleaved should buffer (similar to the `.prefetch()` transformation for interleaved should buffer (similar to the `.prefetch()` transformation for
each interleaved iterator). each interleaved iterator).

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -248,8 +249,9 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
cycle_length, dtype=dtypes.int64, name="cycle_length") cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor( self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length") block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor( if sloppy is not None:
sloppy, dtype=dtypes.bool, name="sloppy") self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
self._buffer_output_elements = convert.optional_param_to_tensor( self._buffer_output_elements = convert.optional_param_to_tensor(
"buffer_output_elements", "buffer_output_elements",
buffer_output_elements, buffer_output_elements,
@ -258,16 +260,34 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
"prefetch_input_elements", "prefetch_input_elements",
prefetch_input_elements, prefetch_input_elements,
argument_default=2 * cycle_length) argument_default=2 * cycle_length)
variant_tensor = ged_ops.parallel_interleave_dataset( if sloppy is None or compat.forward_compatible(2020, 3, 6):
self._input_dataset._variant_tensor, # pylint: disable=protected-access if sloppy is None:
self._map_func.function.captured_inputs, self._deterministic = "default"
self._cycle_length, elif sloppy:
self._block_length, self._deterministic = "false"
self._sloppy, else:
self._buffer_output_elements, self._deterministic = "true"
self._prefetch_input_elements, variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
f=self._map_func.function, self._input_dataset._variant_tensor, # pylint: disable=protected-access
**self._flat_structure) self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func.function,
deterministic=self._deterministic,
**self._flat_structure)
else:
variant_tensor = ged_ops.parallel_interleave_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
self._sloppy,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func.function,
**self._flat_structure)
super(ParallelInterleaveDataset, self).__init__(input_dataset, super(ParallelInterleaveDataset, self).__init__(input_dataset,
variant_tensor) variant_tensor)

View File

@ -1964,6 +1964,10 @@ tf_module {
name: "LeftShift" name: "LeftShift"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "LegacyParallelInterleaveDatasetV2"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method { member_method {
name: "Less" name: "Less"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1964,6 +1964,10 @@ tf_module {
name: "LeftShift" name: "LeftShift"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "LegacyParallelInterleaveDatasetV2"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method { member_method {
name: "Less" name: "Less"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "