Support None option for experimental_interleave sloppiness.

If sloppy=None, the transformation uses the tf.data.Options.experimental_deterministic option to
decide whether to use sloppy (nondeterministic) behavior.

PiperOrigin-RevId: 295252219
Change-Id: I93b4ac182ea4edd8666364827d226593a3d3c7bd
Andrew Audibert 2020-02-14 16:28:36 -08:00 committed by TensorFlower Gardener
parent 33f00e722e
commit 745ed4714d
13 changed files with 284 additions and 101 deletions
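
Illustrative sketch (not part of this change): the resolution logic described in the commit message,
written out as a small Python helper. The helper name and signature are hypothetical; it mirrors the
sloppy -> deterministic-attr mapping added in interleave_ops.py and the kernel's DeterminismPolicy
handling shown below.

# Hypothetical helper, for illustration only.
def effective_determinism(sloppy, experimental_deterministic=True):
  """Returns True if interleave outputs must be produced in deterministic order."""
  if sloppy is None:
    deterministic_attr = "default"   # defer to the tf.data options
  elif sloppy:
    deterministic_attr = "false"     # explicitly nondeterministic
  else:
    deterministic_attr = "true"      # explicitly deterministic
  if deterministic_attr == "default":
    # experimental_deterministic defaults to True, so the default stays deterministic.
    return experimental_deterministic
  return deterministic_attr == "true"

assert effective_determinism(sloppy=False, experimental_deterministic=False)
assert not effective_determinism(sloppy=True)
assert not effective_determinism(sloppy=None, experimental_deterministic=False)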

View File

@ -0,0 +1,22 @@
op {
graph_op_name: "LegacyParallelInterleaveDatasetV2"
visibility: HIDDEN
attr {
name: "f"
description: <<END
A function mapping elements of `input_dataset`, concatenated with
`other_arguments`, to a Dataset variant that contains elements matching
`output_types` and `output_shapes`.
END
}
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
description: <<END
The resulting dataset is similar to the `InterleaveDataset`, with the exception
that if retrieving the next value from a dataset would cause the requester to
block, it will skip that input dataset. This dataset is especially useful
when loading data from variable-latency datastores (e.g. HDFS, GCS), as it
allows the training step to proceed so long as some data is available.
!! WARNING !! This dataset is not deterministic!
END
}
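
For orientation, a hedged sketch of the kind of function the `f` attr describes: each element of the
input dataset is mapped to a Dataset whose elements match `output_types` and `output_shapes`. The
sketch uses the public tf.data API rather than this op directly, purely as an illustration.

import tensorflow as tf

# Each input element x becomes its own small Dataset; an interleave op then
# pulls elements from several of these per-element datasets.
def interleave_fn(x):
  return tf.data.Dataset.from_tensors(x).repeat(3)

# Core (deterministic) interleave, shown only to illustrate the shape of `f`.
dataset = tf.data.Dataset.range(5).interleave(interleave_fn, cycle_length=2)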

View File

@ -90,10 +90,11 @@ constexpr std::array<const char*, 29> kPassThroughOps = {
};
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 7> kFuncDatasetOps = {
constexpr std::array<const char*, 8> kFuncDatasetOps = {
"ExperimentalParallelInterleaveDataset",
"FlatMapDataset",
"InterleaveDataset",
"LegacyParallelInterleaveDatasetV2",
"ParallelInterleaveDataset",
"ParallelInterleaveDatasetV2",
"ParallelInterleaveDatasetV3",

View File

@ -33,10 +33,10 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
};
constexpr std::array<const char*, 4> kDeterministicAttrOps = {
constexpr std::array<const char*, 5> kDeterministicAttrOps = {
"LegacyParallelInterleaveDatasetV2",
"ParallelInterleaveDatasetV3",
"ParallelInterleaveDatasetV4",
"ParallelMapDatasetV2",
"ParseExampleDatasetV2",
};
} // anonymous namespace

View File

@ -46,6 +46,8 @@ namespace experimental {
ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
/* static */ constexpr const char* const
ParallelInterleaveDatasetOp::kBufferOutputElements;
@ -90,15 +92,16 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, bool sloppy, int64 buffer_output_elements,
int64 prefetch_input_elements, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
int64 block_length, DeterminismPolicy deterministic,
int64 buffer_output_elements, int64 prefetch_input_elements,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes, int op_version)
: DatasetBase(DatasetContext(ctx)),
input_(input),
captured_func_(std::move(captured_func)),
cycle_length_(cycle_length),
block_length_(block_length),
sloppy_(sloppy),
deterministic_(deterministic),
buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements),
output_types_(output_types),
@ -106,7 +109,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
traceme_metadata_(
{{"block_length", strings::Printf("%lld", block_length)},
{"cycle_length", strings::Printf("%lld", cycle_length)},
{"deterministic", sloppy ? "false" : "true"}}) {
{"deterministic",
deterministic.IsDeterministic() || deterministic.IsDefault()
? "true"
: "false"}}),
op_version_(op_version) {
input_->Ref();
}
@ -114,8 +121,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
name_utils::IteratorPrefixParams params;
params.op_version = op_version_;
bool deterministic =
deterministic_.IsDeterministic() || deterministic_.IsDefault();
return absl::make_unique<Iterator>(
Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
deterministic);
}
const DataTypeVector& output_dtypes() const override { return output_types_; }
@ -125,7 +138,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
}
string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType);
name_utils::DatasetDebugStringParams params;
params.op_version = op_version_;
return name_utils::DatasetDebugString(kDatasetType, params);
}
Status CheckExternalState() const override {
@ -137,39 +152,62 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
std::vector<std::pair<size_t, Node*>> inputs;
std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
int input_index = 0;
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
Node* sloppy_node;
TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
Node* buffer_output_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
Node* prefetch_input_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
inputs.emplace_back(input_index++, input_node);
std::vector<Node*> other_arguments;
DataTypeVector other_arguments_types;
TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
&other_arguments_types));
list_inputs.emplace_back(input_index++, other_arguments);
Node* cycle_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
inputs.emplace_back(input_index++, cycle_length_node);
Node* block_length_node;
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
inputs.emplace_back(input_index++, block_length_node);
if (op_version_ == 1) {
Node* sloppy_node;
TF_RETURN_IF_ERROR(
b->AddScalar(deterministic_.IsNondeterministic(), &sloppy_node));
inputs.emplace_back(input_index++, sloppy_node);
}
Node* buffer_output_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
inputs.emplace_back(input_index++, buffer_output_elements_node);
Node* prefetch_input_elements_node;
TF_RETURN_IF_ERROR(
b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
inputs.emplace_back(input_index++, prefetch_input_elements_node);
std::vector<std::pair<StringPiece, AttrValue>> attrs;
AttrValue f;
b->BuildAttrValue(captured_func_->func(), &f);
attrs.emplace_back(kFunc, f);
if (op_version_ == 2) {
AttrValue deterministic_attr;
b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
attrs.emplace_back(kDeterministic, deterministic_attr);
}
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
attrs.emplace_back(kTarguments, other_arguments_types_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this,
{{0, input_node},
{2, cycle_length_node},
{3, block_length_node},
{4, sloppy_node},
{5, buffer_output_elements_node},
{6, prefetch_input_elements_node}},
{{1, other_arguments}},
{{kFunc, f}, {kTarguments, other_arguments_types_attr}}, output));
TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
return Status::OK();
}
@ -226,8 +264,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// an element in `interleave_indices_` or `staging_indices_`.
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
explicit Iterator(const Params& params, bool deterministic)
: DatasetIterator<Dataset>(params),
deterministic_(deterministic),
workers_(dataset()->num_threads()),
worker_thread_states_(dataset()->num_threads()) {}
@ -244,7 +283,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// It is implemented so that it matches the deterministic interleave
// unless getting the next element would block and we are allowed to be
// sloppy.
// nondeterministic.
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@ -252,8 +291,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
while (!cancelled_) {
// Wait for an item to become available, blocking if necessary. If we
// are allowed to be sloppy, we can skip over input datasets that do
// not have an item readily available.
// are allowed to be nondeterministic, we can skip over input datasets
// that do not have an item readily available.
bool can_produce_elements = false;
bool must_wait_for_input = true;
for (int64 i = 0; i < interleave_indices_.size(); ++i) {
@ -267,9 +306,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
if (!current_worker->outputs.empty()) {
// We have an element!
next_index_ = index;
const bool element_acquired_sloppily = dataset()->sloppy_ && i > 1;
const bool element_acquired_sloppily = !deterministic_ && i > 1;
if (!element_acquired_sloppily) {
// If the element was acquired in the regular (non-sloppy)
// If the element was acquired in the regular (deterministic)
// order, then advance the current block and cycle pointers to
// the next element in the regular order.
block_count_++;
@ -286,7 +325,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
current_worker->outputs.pop_front();
current_worker->cond_var.notify_one();
return s;
} else if (current_worker->is_producing && !dataset()->sloppy_) {
} else if (current_worker->is_producing && deterministic_) {
// current_worker.outputs.empty(), and we must wait for this
// iterator.
if (next_index_ != index) {
@ -336,10 +375,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
if (must_wait_for_input) {
// Wait for elements to become available.
RecordStop(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
if (deterministic_) {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
} else {
any_element_available_cond_var_.wait(l);
}
RecordStart(ctx);
}
@ -542,7 +581,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// for the main thread to add arguments to `input`, or (2) waiting for
// the main thread to consume an element of `outputs`. The main thread
// waits on cond_var if it is waiting for the worker thread to produce
// an element into `outputs` (this implies sloppy_==false).
// an element into `outputs` (this implies deterministic==true).
condition_variable cond_var;
inline bool MayHaveElements() const {
@ -754,10 +793,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// CHECKPOINT_MARKER_C
// Non-OK iterator creation status has been notified to the
// client.
if (dataset()->sloppy_) {
sloppy_cond_var_.notify_one();
} else {
if (deterministic_) {
workers_[thread_index].cond_var.notify_one();
} else {
any_element_available_cond_var_.notify_one();
}
} else {
bool end_of_sequence = false;
@ -818,10 +857,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
}
worker_thread_states_[thread_index].output_elem.status =
Status::OK();
if (dataset()->sloppy_) {
sloppy_cond_var_.notify_one();
} else {
if (deterministic_) {
workers_[thread_index].cond_var.notify_one();
} else {
any_element_available_cond_var_.notify_one();
}
// CHECKPOINT_MARKER_E
// Output element or iterator status has been sent to the
@ -1040,9 +1079,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// Mutex & condition variable to guard mutable iterator internals and
// coordinate among worker threads and client thread[s].
mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
// The main thread waits on this condition variable if running in sloppy
// mode and no values are available.
condition_variable sloppy_cond_var_;
// The main thread waits on this condition variable if running in
// nondeterministic mode and no values are available.
condition_variable any_element_available_cond_var_;
// Whether outputs must be produced in deterministic order.
const bool deterministic_;
// Mutex used to wait for a consistent state while checkpointing.
// Only Save and Restore require an exclusive lock on this mutex. In
// other scenarios we just acquire a shared lock so the pipeline's
@ -1087,21 +1128,29 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
const std::unique_ptr<CapturedFunction> captured_func_;
const int64 cycle_length_;
const int64 block_length_;
const bool sloppy_;
const DeterminismPolicy deterministic_;
const int64 buffer_output_elements_;
const int64 prefetch_input_elements_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
const TraceMeMetadata traceme_metadata_;
const int op_version_;
};
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
: UnaryDatasetOpKernel(ctx),
op_version_(ctx->HasAttr(kDeterministic) ? 2 : 1) {
FunctionMetadata::Params params;
params.is_multi_device_function = true;
OP_REQUIRES_OK(ctx,
FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
if (op_version_ == 2) {
std::string deterministic;
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
OP_REQUIRES_OK(
ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
}
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
@ -1119,8 +1168,17 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
OP_REQUIRES(ctx, block_length > 0,
errors::InvalidArgument("`block_length` must be > 0"));
bool sloppy = false;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
if (op_version_ == 1) {
bool sloppy = false;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
if (sloppy) {
deterministic_ =
DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
} else {
deterministic_ =
DeterminismPolicy(DeterminismPolicy::Type::kDeterministic);
}
}
int64 buffer_output_elements = 0;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
@ -1141,8 +1199,9 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
&captured_func));
*output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
block_length, sloppy, buffer_output_elements,
prefetch_input_elements, output_types_, output_shapes_);
block_length, deterministic_, buffer_output_elements,
prefetch_input_elements, output_types_, output_shapes_,
op_version_);
}
namespace {
@ -1151,9 +1210,13 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(
Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("LegacyParallelInterleaveDatasetV2").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("LegacyParallelInterleaveDatasetV2");
} // namespace
} // namespace experimental

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
namespace tensorflow {
namespace data {
@ -27,11 +28,12 @@ namespace experimental {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "ParallelInterleave";
static constexpr const char* const kDatasetType = "LegacyParallelInterleave";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kOtherArguments = "other_arguments";
static constexpr const char* const kCycleLength = "cycle_length";
static constexpr const char* const kBlockLength = "block_length";
static constexpr const char* const kDeterministic = "deterministic";
static constexpr const char* const kSloppy = "sloppy";
static constexpr const char* const kBufferOutputElements =
"buffer_output_elements";
@ -50,10 +52,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
private:
class Dataset;
const int op_version_;
std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
DeterminismPolicy deterministic_;
};
} // namespace experimental

View File

@ -20,33 +20,37 @@ namespace experimental {
namespace {
constexpr char kNodeName[] = "parallel_interleave_dataset";
constexpr int kOpVersion = 2;
class ParallelInterleaveDatasetParams : public DatasetParams {
public:
template <typename T>
ParallelInterleaveDatasetParams(
T input_dataset_params, std::vector<Tensor> other_arguments,
int64 cycle_length, int64 block_length, bool sloppy,
int64 cycle_length, int64 block_length, const std::string& deterministic,
int64 buffer_output_elements, int64 prefetch_input_elements,
FunctionDefHelper::AttrValueWrapper func,
std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
DataTypeVector output_dtypes,
std::vector<PartialTensorShape> output_shapes, string node_name)
const DataTypeVector& output_dtypes,
const std::vector<PartialTensorShape>& output_shapes, string node_name)
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
std::move(node_name)),
other_arguments_(std::move(other_arguments)),
cycle_length_(cycle_length),
block_length_(block_length),
sloppy_(sloppy),
deterministic_(deterministic),
buffer_output_elements_(buffer_output_elements),
prefetch_input_elements_(prefetch_input_elements),
func_(std::move(func)),
func_lib_(std::move(func_lib)),
type_arguments_(std::move(type_arguments)) {
input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
iterator_prefix_ =
name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
input_dataset_params.iterator_prefix());
op_version_ = kOpVersion;
name_utils::IteratorPrefixParams params;
params.op_version = op_version_;
iterator_prefix_ = name_utils::IteratorPrefix(
input_dataset_params.dataset_type(),
input_dataset_params.iterator_prefix(), params);
}
std::vector<Tensor> GetInputTensors() const override {
@ -55,7 +59,6 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
CreateTensor<int64>(TensorShape({}), {cycle_length_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {block_length_}));
input_tensors.emplace_back(CreateTensor<bool>(TensorShape({}), {sloppy_}));
input_tensors.emplace_back(
CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
input_tensors.emplace_back(
@ -71,7 +74,6 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
}
input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
input_names->emplace_back(ParallelInterleaveDatasetOp::kSloppy);
input_names->emplace_back(
ParallelInterleaveDatasetOp::kBufferOutputElements);
input_names->emplace_back(
@ -82,6 +84,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
Status GetAttributes(AttributeVector* attr_vector) const override {
*attr_vector = {
{ParallelInterleaveDatasetOp::kFunc, func_},
{ParallelInterleaveDatasetOp::kDeterministic, deterministic_},
{ParallelInterleaveDatasetOp::kTarguments, type_arguments_},
{ParallelInterleaveDatasetOp::kOutputShapes, output_shapes_},
{ParallelInterleaveDatasetOp::kOutputTypes, output_dtypes_}};
@ -98,7 +101,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
std::vector<Tensor> other_arguments_;
int64 cycle_length_;
int64 block_length_;
bool sloppy_;
std::string deterministic_;
int64 buffer_output_elements_;
int64 prefetch_input_elements_;
FunctionDefHelper::AttrValueWrapper func_;
@ -117,7 +120,7 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
{TensorSliceDatasetOp::kOutputShapes, output_shapes}});
}
// Test case 1: cycle_length = 1, block_length = 1, sloppy = false,
// Test case 1: cycle_length = 1, block_length = 1, deterministic = true,
// buffer_output_elements = 1, prefetch_input_elements = 1.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -129,7 +132,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*sloppy=*/false,
/*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/1,
/*func=*/
@ -143,7 +146,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
/*node_name=*/kNodeName);
}
// Test case 2: cycle_length = 2, block_length = 1, sloppy = false,
// Test case 2: cycle_length = 2, block_length = 1, deterministic = true,
// buffer_output_elements = 1, prefetch_input_elements = 0.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -155,7 +158,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*other_arguments=*/{},
/*cycle_length=*/2,
/*block_length=*/1,
/*sloppy=*/false,
/*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/0,
/*func=*/
@ -169,7 +172,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
/*node_name=*/kNodeName);
}
// Test case 3: cycle_length = 3, block_length = 1, sloppy = true,
// Test case 3: cycle_length = 3, block_length = 1, deterministic = false,
// buffer_output_elements = 3, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -181,7 +184,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*other_arguments=*/{},
/*cycle_length=*/3,
/*block_length=*/1,
/*sloppy=*/true,
/*deterministic=*/DeterminismPolicy::kNondeterministic,
/*buffer_output_elements=*/3,
/*prefetch_input_elements=*/2,
/*func=*/
@ -195,7 +198,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
/*node_name=*/kNodeName);
}
// Test case 4: cycle_length = 5, block_length = 1, sloppy = true
// Test case 4: cycle_length = 5, block_length = 1, deterministic = false
// buffer_output_elements = 1, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -207,7 +210,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*other_arguments=*/{},
/*cycle_length=*/5,
/*block_length=*/1,
/*sloppy=*/true,
/*deterministic=*/DeterminismPolicy::kNondeterministic,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/2,
/*func=*/
@ -221,7 +224,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
/*node_name=*/kNodeName);
}
// Test case 5: cycle_length = 2, block_length = 2, sloppy = false
// Test case 5: cycle_length = 2, block_length = 2, deterministic = true
// buffer_output_elements = 2, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
@ -233,7 +236,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
/*other_arguments=*/{},
/*cycle_length=*/2,
/*block_length=*/2,
/*sloppy=*/false,
/*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/2,
/*prefetch_input_elements=*/2,
/*func=*/
@ -256,7 +259,7 @@ ParallelInterleaveDatasetParams EmptyInputParams() {
/*other_arguments=*/{},
/*cycle_length=*/2,
/*block_length=*/2,
/*sloppy=*/true,
/*deterministic=*/DeterminismPolicy::kNondeterministic,
/*buffer_output_elements=*/2,
/*prefetch_input_elements=*/2,
/*func=*/
@ -280,7 +283,7 @@ ParallelInterleaveDatasetParams InvalidCycleLengthParams() {
/*other_arguments=*/{},
/*cycle_length=*/0,
/*block_length=*/1,
/*sloppy=*/false,
/*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/1,
/*func=*/
@ -304,7 +307,7 @@ ParallelInterleaveDatasetParams InvalidBlockLengthParams() {
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/-1,
/*sloppy=*/false,
/*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/1,
/*func=*/
@ -328,7 +331,7 @@ ParallelInterleaveDatasetParams InvalidBufferOutputElementsParams() {
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*sloppy=*/false,
/*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/0,
/*prefetch_input_elements=*/1,
/*func=*/
@ -352,7 +355,7 @@ ParallelInterleaveDatasetParams InvalidPrefetchInputElementsParams() {
/*other_arguments=*/{},
/*cycle_length=*/1,
/*block_length=*/1,
/*sloppy=*/false,
/*deterministic=*/DeterminismPolicy::kDeterministic,
/*buffer_output_elements=*/1,
/*prefetch_input_elements=*/-1,
/*func=*/
@ -412,8 +415,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, DatasetNodeName) {
TEST_F(ParallelInterleaveDatasetOpTest, DatasetTypeString) {
auto dataset_params = ParallelInterleaveDatasetParams1();
TF_ASSERT_OK(Initialize(dataset_params));
name_utils::OpNameParams params;
params.op_version = dataset_params.op_version();
TF_ASSERT_OK(CheckDatasetTypeString(
name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType)));
name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType, params)));
}
TEST_F(ParallelInterleaveDatasetOpTest, DatasetOutputDtypes) {
@ -461,9 +466,11 @@ TEST_F(ParallelInterleaveDatasetOpTest, IteratorOutputShapes) {
TEST_F(ParallelInterleaveDatasetOpTest, IteratorPrefix) {
auto dataset_params = ParallelInterleaveDatasetParams1();
TF_ASSERT_OK(Initialize(dataset_params));
name_utils::IteratorPrefixParams params;
params.op_version = dataset_params.op_version();
TF_ASSERT_OK(CheckIteratorPrefix(
name_utils::IteratorPrefix(ParallelInterleaveDatasetOp::kDatasetType,
dataset_params.iterator_prefix())));
dataset_params.iterator_prefix(), params)));
}
std::vector<IteratorSaveAndRestoreTestCase<ParallelInterleaveDatasetParams>>

View File

@ -561,6 +561,26 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
// This is the V2 of ParallelInterleaveDataset, renamed to differentiate it
// from the non-experimental ParallelInterleaveDataset op.
REGISTER_OP("LegacyParallelInterleaveDatasetV2")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
.Input("cycle_length: int64")
.Input("block_length: int64")
.Input("buffer_output_elements: int64")
.Input("prefetch_input_elements: int64")
.Output("handle: variant")
.Attr("f: func")
// "true", "false", or "default".
.Attr("deterministic: string = 'default'")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
// This op is no longer used. We keep it so that we can read graphs written by
// old versions of TensorFlow.
REGISTER_OP("ExperimentalParallelInterleaveDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")

View File

@ -466,6 +466,7 @@ tf_py_test(
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/experimental/ops:interleave_ops",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",

View File

@ -26,6 +26,7 @@ import numpy as np
from six.moves import zip_longest
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
@ -729,6 +730,40 @@ class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
results.append(elements)
self.assertAllEqual(results[0], results[1])
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              sloppy=[None, True, False], global_determinism=[True, False])))
  def testDeterminismConfiguration(self, sloppy, global_determinism):
    if sloppy is None:
      expect_determinism = global_determinism
    else:
      expect_determinism = not sloppy
    elements = list(range(1000))

    def dataset_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
      dataset = dataset.apply(
          interleave_ops.parallel_interleave(
              interleave_fn, cycle_length=10, sloppy=sloppy))
      opts = dataset_ops.Options()
      opts.experimental_deterministic = global_determinism
      dataset = dataset.with_options(opts)
      return dataset

    self.checkDeterminism(dataset_fn, expect_determinism, elements)


if __name__ == "__main__":
  test.main()

View File

@ -76,9 +76,11 @@ def parallel_interleave(map_func,
cycle_length: The number of input `Dataset`s to interleave from in parallel.
block_length: The number of consecutive elements to pull from an input
`Dataset` before advancing to the next input `Dataset`.
sloppy: If false, elements are produced in deterministic order. Otherwise,
the implementation is allowed, for the sake of expediency, to produce
elements in a non-deterministic order.
sloppy: A boolean controlling whether determinism should be traded for
performance by allowing elements to be produced out of order. If
`sloppy` is `None`, the `tf.data.Options.experimental_deterministic`
dataset option (`True` by default) is used to decide whether to enforce a
deterministic order.
buffer_output_elements: The number of elements each iterator being
interleaved should buffer (similar to the `.prefetch()` transformation for
each interleaved iterator).
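
A hedged usage sketch of the documented behavior (the cycle length, repeat count, and the public
tf.data.experimental.parallel_interleave alias are assumptions of this example, not taken from the
diff): with sloppy=None, ordering is governed by the dataset-level options.

import tensorflow as tf

def make_dataset(x):
  return tf.data.Dataset.from_tensors(x).repeat(4)

dataset = tf.data.Dataset.range(100).apply(
    tf.data.experimental.parallel_interleave(
        make_dataset, cycle_length=8, sloppy=None))  # defer to the options below

options = tf.data.Options()
options.experimental_deterministic = False  # allow out-of-order elements
dataset = dataset.with_options(options)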

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
@ -248,8 +249,9 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
cycle_length, dtype=dtypes.int64, name="cycle_length")
self._block_length = ops.convert_to_tensor(
block_length, dtype=dtypes.int64, name="block_length")
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
if sloppy is not None:
self._sloppy = ops.convert_to_tensor(
sloppy, dtype=dtypes.bool, name="sloppy")
self._buffer_output_elements = convert.optional_param_to_tensor(
"buffer_output_elements",
buffer_output_elements,
@ -258,16 +260,34 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
"prefetch_input_elements",
prefetch_input_elements,
argument_default=2 * cycle_length)
variant_tensor = ged_ops.parallel_interleave_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
self._sloppy,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func.function,
**self._flat_structure)
if sloppy is None or compat.forward_compatible(2020, 3, 6):
if sloppy is None:
self._deterministic = "default"
elif sloppy:
self._deterministic = "false"
else:
self._deterministic = "true"
variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func.function,
deterministic=self._deterministic,
**self._flat_structure)
else:
variant_tensor = ged_ops.parallel_interleave_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
self._cycle_length,
self._block_length,
self._sloppy,
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func.function,
**self._flat_structure)
super(ParallelInterleaveDataset, self).__init__(input_dataset,
variant_tensor)

View File

@ -1964,6 +1964,10 @@ tf_module {
name: "LeftShift"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "LegacyParallelInterleaveDatasetV2"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method {
name: "Less"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1964,6 +1964,10 @@ tf_module {
name: "LeftShift"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "LegacyParallelInterleaveDatasetV2"
argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
}
member_method {
name: "Less"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "