Support None option for experimental_interleave sloppiness.
If sloppy=None, the transform will use the experimental_deterministic option to determine whether to use sloppy behavior.

PiperOrigin-RevId: 295252219
Change-Id: I93b4ac182ea4edd8666364827d226593a3d3c7bd
parent 33f00e722e
commit 745ed4714d
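In user-facing terms: `tf.data.experimental.parallel_interleave` keeps its `sloppy` argument, but `sloppy=None` now defers the ordering decision to the `tf.data.Options.experimental_deterministic` option, as exercised by the new Python test in the diff. A minimal usage sketch (editorial illustration only, not part of the commit; the file pattern and `make_dataset` helper are hypothetical):

import tensorflow as tf

def make_dataset(filename):
  # Hypothetical map_func; any function returning a Dataset works here.
  return tf.data.TFRecordDataset(filename)

files = tf.data.Dataset.list_files("/tmp/data/*.tfrecord")  # hypothetical path
dataset = files.apply(
    tf.data.experimental.parallel_interleave(
        make_dataset, cycle_length=4, sloppy=None))  # None: defer to options

options = tf.data.Options()
# With sloppy=None, this option decides whether elements may be reordered;
# an explicit sloppy=True/False still takes precedence.
options.experimental_deterministic = False
dataset = dataset.with_options(options)
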
@@ -0,0 +1,22 @@
+op {
+  graph_op_name: "LegacyParallelInterleaveDatasetV2"
+  visibility: HIDDEN
+  attr {
+    name: "f"
+    description: <<END
+A function mapping elements of `input_dataset`, concatenated with
+`other_arguments`, to a Dataset variant that contains elements matching
+`output_types` and `output_shapes`.
+END
+  }
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+  description: <<END
+The resulting dataset is similar to the `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+END
+}
@@ -90,10 +90,11 @@ constexpr std::array<const char*, 29> kPassThroughOps = {
};

// TODO(frankchn): Process functions within kFuncDatasetOps as well.
-constexpr std::array<const char*, 7> kFuncDatasetOps = {
+constexpr std::array<const char*, 8> kFuncDatasetOps = {
    "ExperimentalParallelInterleaveDataset",
    "FlatMapDataset",
    "InterleaveDataset",
+    "LegacyParallelInterleaveDatasetV2",
    "ParallelInterleaveDataset",
    "ParallelInterleaveDatasetV2",
    "ParallelInterleaveDatasetV3",

@@ -33,10 +33,10 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
};

constexpr std::array<const char*, 4> kDeterministicAttrOps = {
+    "LegacyParallelInterleaveDatasetV2",
    "ParallelInterleaveDatasetV3",
    "ParallelInterleaveDatasetV4",
    "ParallelMapDatasetV2",
    "ParseExampleDatasetV2",
};

} // anonymous namespace

@@ -46,6 +46,8 @@ namespace experimental {
    ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBlockLength;
+/* static */ constexpr const char* const
+    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBufferOutputElements;

@@ -90,15 +92,16 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
-          int64 block_length, bool sloppy, int64 buffer_output_elements,
-          int64 prefetch_input_elements, const DataTypeVector& output_types,
-          const std::vector<PartialTensorShape>& output_shapes)
+          int64 block_length, DeterminismPolicy deterministic,
+          int64 buffer_output_elements, int64 prefetch_input_elements,
+          const DataTypeVector& output_types,
+          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
-        sloppy_(sloppy),
+        deterministic_(deterministic),
        buffer_output_elements_(buffer_output_elements),
        prefetch_input_elements_(prefetch_input_elements),
        output_types_(output_types),

@@ -106,7 +109,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
        traceme_metadata_(
            {{"block_length", strings::Printf("%lld", block_length)},
             {"cycle_length", strings::Printf("%lld", cycle_length)},
-             {"deterministic", sloppy ? "false" : "true"}}) {
+             {"deterministic",
+              deterministic.IsDeterministic() || deterministic.IsDefault()
+                  ? "true"
+                  : "false"}}),
+        op_version_(op_version) {
    input_->Ref();
  }

@@ -114,8 +121,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
-    return absl::make_unique<Iterator>(Iterator::Params{
-        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
+    name_utils::IteratorPrefixParams params;
+    params.op_version = op_version_;
+    bool deterministic =
+        deterministic_.IsDeterministic() || deterministic_.IsDefault();
+    return absl::make_unique<Iterator>(
+        Iterator::Params{
+            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
+        deterministic);
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

@@ -125,7 +138,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  }

  string DebugString() const override {
-    return name_utils::DatasetDebugString(kDatasetType);
+    name_utils::DatasetDebugStringParams params;
+    params.op_version = op_version_;
+    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  Status CheckExternalState() const override {
@@ -137,39 +152,62 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
+    std::vector<std::pair<size_t, Node*>> inputs;
+    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
+    int input_index = 0;
+
    Node* input_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
-    Node* cycle_length_node;
-    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
-    Node* block_length_node;
-    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
-    Node* sloppy_node;
-    TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
-    Node* buffer_output_elements_node;
-    TF_RETURN_IF_ERROR(
-        b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
-    Node* prefetch_input_elements_node;
-    TF_RETURN_IF_ERROR(
-        b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
+    inputs.emplace_back(input_index++, input_node);

    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
+    list_inputs.emplace_back(input_index++, other_arguments);
+
+    Node* cycle_length_node;
+    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+    inputs.emplace_back(input_index++, cycle_length_node);
+
+    Node* block_length_node;
+    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+    inputs.emplace_back(input_index++, block_length_node);
+
+    if (op_version_ == 1) {
+      Node* sloppy_node;
+      TF_RETURN_IF_ERROR(
+          b->AddScalar(deterministic_.IsNondeterministic(), &sloppy_node));
+      inputs.emplace_back(input_index++, sloppy_node);
+    }
+
+    Node* buffer_output_elements_node;
+    TF_RETURN_IF_ERROR(
+        b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
+    inputs.emplace_back(input_index++, buffer_output_elements_node);
+
+    Node* prefetch_input_elements_node;
+    TF_RETURN_IF_ERROR(
+        b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
+    inputs.emplace_back(input_index++, prefetch_input_elements_node);
+
+    std::vector<std::pair<StringPiece, AttrValue>> attrs;

    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
+    attrs.emplace_back(kFunc, f);

+    if (op_version_ == 2) {
+      AttrValue deterministic_attr;
+      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
+      attrs.emplace_back(kDeterministic, deterministic_attr);
+    }
+
    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+    attrs.emplace_back(kTarguments, other_arguments_types_attr);

-    TF_RETURN_IF_ERROR(b->AddDataset(
-        this,
-        {{0, input_node},
-         {2, cycle_length_node},
-         {3, block_length_node},
-         {4, sloppy_node},
-         {5, buffer_output_elements_node},
-         {6, prefetch_input_elements_node}},
-        {{1, other_arguments}},
-        {{kFunc, f}, {kTarguments, other_arguments_types_attr}}, output));
+    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
    return Status::OK();
  }

@@ -226,8 +264,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  // an element in `interleave_indices_` or `staging_indices_`.
  class Iterator : public DatasetIterator<Dataset> {
   public:
-    explicit Iterator(const Params& params)
+    explicit Iterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
+          deterministic_(deterministic),
          workers_(dataset()->num_threads()),
          worker_thread_states_(dataset()->num_threads()) {}

@@ -244,7 +283,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {

    // It is implemented so that it matches the deterministic interleave
    // unless getting the next element would block and we are allowed to be
-    // sloppy.
+    // nondeterministic.
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {

@@ -252,8 +291,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
      while (!cancelled_) {
        // Wait for an item to become available, blocking if necessary. If we
-        // are allowed to be sloppy, we can skip over input datasets that do
-        // not have an item readily available.
+        // are allowed to be nondeterministic, we can skip over input datasets
+        // that do not have an item readily available.
        bool can_produce_elements = false;
        bool must_wait_for_input = true;
        for (int64 i = 0; i < interleave_indices_.size(); ++i) {

@@ -267,9 +306,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
          if (!current_worker->outputs.empty()) {
            // We have an element!
            next_index_ = index;
-            const bool element_acquired_sloppily = dataset()->sloppy_ && i > 1;
+            const bool element_acquired_sloppily = !deterministic_ && i > 1;
            if (!element_acquired_sloppily) {
-              // If the element was acquired in the regular (non-sloppy)
+              // If the element was acquired in the regular (deterministic)
              // order, then advance the current block and cycle pointers to
              // the next element in the regular order.
              block_count_++;

@@ -286,7 +325,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
            current_worker->outputs.pop_front();
            current_worker->cond_var.notify_one();
            return s;
-          } else if (current_worker->is_producing && !dataset()->sloppy_) {
+          } else if (current_worker->is_producing && deterministic_) {
            // current_worker.outputs.empty(), and we must wait for this
            // iterator.
            if (next_index_ != index) {

@@ -336,10 +375,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
        if (must_wait_for_input) {
          // Wait for elements to become available.
          RecordStop(ctx);
-          if (dataset()->sloppy_) {
-            sloppy_cond_var_.wait(l);
-          } else {
+          if (deterministic_) {
            workers_[interleave_indices_[next_index_]].cond_var.wait(l);
+          } else {
+            any_element_available_cond_var_.wait(l);
          }
          RecordStart(ctx);
        }

@@ -542,7 +581,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      // for the main thread to add arguments to `input`, or (2) waiting for
      // the main thread to consume an element of `outputs`. The main thread
      // waits on cond_var if it is waiting for the worker thread to produce
-      // an element into `outputs` (this implies sloppy_==false).
+      // an element into `outputs` (this implies deterministic==true).
      condition_variable cond_var;

      inline bool MayHaveElements() const {

@@ -754,10 +793,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
              // CHECKPOINT_MARKER_C
              // Non-OK iterator creation status has been notified to the
              // client.
-              if (dataset()->sloppy_) {
-                sloppy_cond_var_.notify_one();
-              } else {
+              if (deterministic_) {
                workers_[thread_index].cond_var.notify_one();
+              } else {
+                any_element_available_cond_var_.notify_one();
              }
            } else {
              bool end_of_sequence = false;

@@ -818,10 +857,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
              }
              worker_thread_states_[thread_index].output_elem.status =
                  Status::OK();
-              if (dataset()->sloppy_) {
-                sloppy_cond_var_.notify_one();
-              } else {
+              if (deterministic_) {
                workers_[thread_index].cond_var.notify_one();
+              } else {
+                any_element_available_cond_var_.notify_one();
              }
              // CHECKPOINT_MARKER_E
              // Output element or iterator status has been sent to the

@@ -1040,9 +1079,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
    // Mutex & condition variable to guard mutable iterator internals and
    // coordinate among worker threads and client thread[s].
    mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
-    // The main thread waits on this condition variable if running in sloppy
-    // mode and no values are available.
-    condition_variable sloppy_cond_var_;
+    // The main thread waits on this condition variable if running in
+    // nondeterministic mode and no values are available.
+    condition_variable any_element_available_cond_var_;
+    // Whether outputs must be produced in deterministic order.
+    const bool deterministic_;
    // Mutex used to wait for a consistent state while checkpointing.
    // Only Save and Restore require an exclusive lock on this mutex. In
    // other scenarios we just acquire a shared lock so the pipeline's
@@ -1087,21 +1128,29 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int64 cycle_length_;
  const int64 block_length_;
-  const bool sloppy_;
+  const DeterminismPolicy deterministic_;
  const int64 buffer_output_elements_;
  const int64 prefetch_input_elements_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
+  const int op_version_;
};

ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
-    : UnaryDatasetOpKernel(ctx) {
+    : UnaryDatasetOpKernel(ctx),
+      op_version_(ctx->HasAttr(kDeterministic) ? 2 : 1) {
  FunctionMetadata::Params params;
  params.is_multi_device_function = true;
  OP_REQUIRES_OK(ctx,
                 FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
+  if (op_version_ == 2) {
+    std::string deterministic;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
+    OP_REQUIRES_OK(
+        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
+  }
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

@@ -1119,8 +1168,17 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));

-  bool sloppy = false;
-  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
+  if (op_version_ == 1) {
+    bool sloppy = false;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
+    if (sloppy) {
+      deterministic_ =
+          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
+    } else {
+      deterministic_ =
+          DeterminismPolicy(DeterminismPolicy::Type::kDeterministic);
+    }
+  }

  int64 buffer_output_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,

@@ -1141,8 +1199,9 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                          &captured_func));

  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
-                        block_length, sloppy, buffer_output_elements,
-                        prefetch_input_elements, output_types_, output_shapes_);
+                        block_length, deterministic_, buffer_output_elements,
+                        prefetch_input_elements, output_types_, output_shapes_,
+                        op_version_);
}

namespace {

@@ -1151,9 +1210,13 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);
+REGISTER_KERNEL_BUILDER(
+    Name("LegacyParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+    ParallelInterleaveDatasetOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
+REGISTER_INPUT_COLOCATION_EXEMPTION("LegacyParallelInterleaveDatasetV2");

} // namespace
} // namespace experimental
@@ -17,6 +17,7 @@ limitations under the License.

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/captured_function.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"

namespace tensorflow {
namespace data {

@@ -27,11 +28,12 @@ namespace experimental {

class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
 public:
-  static constexpr const char* const kDatasetType = "ParallelInterleave";
+  static constexpr const char* const kDatasetType = "LegacyParallelInterleave";
  static constexpr const char* const kInputDataset = "input_dataset";
  static constexpr const char* const kOtherArguments = "other_arguments";
  static constexpr const char* const kCycleLength = "cycle_length";
  static constexpr const char* const kBlockLength = "block_length";
+  static constexpr const char* const kDeterministic = "deterministic";
  static constexpr const char* const kSloppy = "sloppy";
  static constexpr const char* const kBufferOutputElements =
      "buffer_output_elements";

@@ -50,10 +52,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {

 private:
  class Dataset;
+  const int op_version_;

  std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
+  DeterminismPolicy deterministic_;
};

} // namespace experimental
@@ -20,33 +20,37 @@ namespace experimental {
namespace {

constexpr char kNodeName[] = "parallel_interleave_dataset";
+constexpr int kOpVersion = 2;

class ParallelInterleaveDatasetParams : public DatasetParams {
 public:
  template <typename T>
  ParallelInterleaveDatasetParams(
      T input_dataset_params, std::vector<Tensor> other_arguments,
-      int64 cycle_length, int64 block_length, bool sloppy,
+      int64 cycle_length, int64 block_length, const std::string& deterministic,
      int64 buffer_output_elements, int64 prefetch_input_elements,
      FunctionDefHelper::AttrValueWrapper func,
      std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
-      DataTypeVector output_dtypes,
-      std::vector<PartialTensorShape> output_shapes, string node_name)
+      const DataTypeVector& output_dtypes,
+      const std::vector<PartialTensorShape>& output_shapes, string node_name)
      : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
                      std::move(node_name)),
        other_arguments_(std::move(other_arguments)),
        cycle_length_(cycle_length),
        block_length_(block_length),
-        sloppy_(sloppy),
+        deterministic_(deterministic),
        buffer_output_elements_(buffer_output_elements),
        prefetch_input_elements_(prefetch_input_elements),
        func_(std::move(func)),
        func_lib_(std::move(func_lib)),
        type_arguments_(std::move(type_arguments)) {
    input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
-    iterator_prefix_ =
-        name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
-                                   input_dataset_params.iterator_prefix());
+    op_version_ = kOpVersion;
+    name_utils::IteratorPrefixParams params;
+    params.op_version = op_version_;
+    iterator_prefix_ = name_utils::IteratorPrefix(
+        input_dataset_params.dataset_type(),
+        input_dataset_params.iterator_prefix(), params);
  }

  std::vector<Tensor> GetInputTensors() const override {

@@ -55,7 +59,6 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
        CreateTensor<int64>(TensorShape({}), {cycle_length_}));
    input_tensors.emplace_back(
        CreateTensor<int64>(TensorShape({}), {block_length_}));
-    input_tensors.emplace_back(CreateTensor<bool>(TensorShape({}), {sloppy_}));
    input_tensors.emplace_back(
        CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
    input_tensors.emplace_back(

@@ -71,7 +74,6 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
    }
    input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
    input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
-    input_names->emplace_back(ParallelInterleaveDatasetOp::kSloppy);
    input_names->emplace_back(
        ParallelInterleaveDatasetOp::kBufferOutputElements);
    input_names->emplace_back(

@@ -82,6 +84,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
  Status GetAttributes(AttributeVector* attr_vector) const override {
    *attr_vector = {
        {ParallelInterleaveDatasetOp::kFunc, func_},
+        {ParallelInterleaveDatasetOp::kDeterministic, deterministic_},
        {ParallelInterleaveDatasetOp::kTarguments, type_arguments_},
        {ParallelInterleaveDatasetOp::kOutputShapes, output_shapes_},
        {ParallelInterleaveDatasetOp::kOutputTypes, output_dtypes_}};

@@ -98,7 +101,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
  std::vector<Tensor> other_arguments_;
  int64 cycle_length_;
  int64 block_length_;
-  bool sloppy_;
+  std::string deterministic_;
  int64 buffer_output_elements_;
  int64 prefetch_input_elements_;
  FunctionDefHelper::AttrValueWrapper func_;
@@ -117,7 +120,7 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
      {TensorSliceDatasetOp::kOutputShapes, output_shapes}});
}

-// Test case 1: cycle_length = 1, block_length = 1, sloppy = false,
+// Test case 1: cycle_length = 1, block_length = 1, deterministic = true,
// buffer_output_elements = 1, prefetch_input_elements = 1.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(

@@ -129,7 +132,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
-      /*sloppy=*/false,
+      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*buffer_output_elements=*/1,
      /*prefetch_input_elements=*/1,
      /*func=*/

@@ -143,7 +146,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
      /*node_name=*/kNodeName);
}

-// Test case 2: cycle_length = 2, block_length = 1, sloppy = false,
+// Test case 2: cycle_length = 2, block_length = 1, deterministic = true,
// buffer_output_elements = 1, prefetch_input_elements = 0.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(

@@ -155,7 +158,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
      /*other_arguments=*/{},
      /*cycle_length=*/2,
      /*block_length=*/1,
-      /*sloppy=*/false,
+      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*buffer_output_elements=*/1,
      /*prefetch_input_elements=*/0,
      /*func=*/

@@ -169,7 +172,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
      /*node_name=*/kNodeName);
}

-// Test case 3: cycle_length = 3, block_length = 1, sloppy = true,
+// Test case 3: cycle_length = 3, block_length = 1, deterministic = false,
// buffer_output_elements = 3, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(

@@ -181,7 +184,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
      /*other_arguments=*/{},
      /*cycle_length=*/3,
      /*block_length=*/1,
-      /*sloppy=*/true,
+      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*buffer_output_elements=*/3,
      /*prefetch_input_elements=*/2,
      /*func=*/

@@ -195,7 +198,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
      /*node_name=*/kNodeName);
}

-// Test case 4: cycle_length = 5, block_length = 1, sloppy = true
+// Test case 4: cycle_length = 5, block_length = 1, deterministic = false
// buffer_output_elements = 1, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(

@@ -207,7 +210,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
      /*other_arguments=*/{},
      /*cycle_length=*/5,
      /*block_length=*/1,
-      /*sloppy=*/true,
+      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*buffer_output_elements=*/1,
      /*prefetch_input_elements=*/2,
      /*func=*/

@@ -221,7 +224,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
      /*node_name=*/kNodeName);
}

-// Test case 5: cycle_length = 2, block_length = 2, sloppy = false
+// Test case 5: cycle_length = 2, block_length = 2, deterministic = true
// buffer_output_elements = 2, prefetch_input_elements = 2.
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(

@@ -233,7 +236,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
      /*other_arguments=*/{},
      /*cycle_length=*/2,
      /*block_length=*/2,
-      /*sloppy=*/false,
+      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*buffer_output_elements=*/2,
      /*prefetch_input_elements=*/2,
      /*func=*/

@@ -256,7 +259,7 @@ ParallelInterleaveDatasetParams EmptyInputParams() {
      /*other_arguments=*/{},
      /*cycle_length=*/2,
      /*block_length=*/2,
-      /*sloppy=*/true,
+      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*buffer_output_elements=*/2,
      /*prefetch_input_elements=*/2,
      /*func=*/

@@ -280,7 +283,7 @@ ParallelInterleaveDatasetParams InvalidCycleLengthParams() {
      /*other_arguments=*/{},
      /*cycle_length=*/0,
      /*block_length=*/1,
-      /*sloppy=*/false,
+      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*buffer_output_elements=*/1,
      /*prefetch_input_elements=*/1,
      /*func=*/

@@ -304,7 +307,7 @@ ParallelInterleaveDatasetParams InvalidBlockLengthParams() {
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/-1,
-      /*sloppy=*/false,
+      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*buffer_output_elements=*/1,
      /*prefetch_input_elements=*/1,
      /*func=*/

@@ -328,7 +331,7 @@ ParallelInterleaveDatasetParams InvalidBufferOutputElementsParams() {
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
-      /*sloppy=*/false,
+      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*buffer_output_elements=*/0,
      /*prefetch_input_elements=*/1,
      /*func=*/

@@ -352,7 +355,7 @@ ParallelInterleaveDatasetParams InvalidPrefetchInputElementsParams() {
      /*other_arguments=*/{},
      /*cycle_length=*/1,
      /*block_length=*/1,
-      /*sloppy=*/false,
+      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*buffer_output_elements=*/1,
      /*prefetch_input_elements=*/-1,
      /*func=*/
@@ -412,8 +415,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, DatasetNodeName) {
TEST_F(ParallelInterleaveDatasetOpTest, DatasetTypeString) {
  auto dataset_params = ParallelInterleaveDatasetParams1();
  TF_ASSERT_OK(Initialize(dataset_params));
+  name_utils::OpNameParams params;
+  params.op_version = dataset_params.op_version();
  TF_ASSERT_OK(CheckDatasetTypeString(
-      name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType)));
+      name_utils::OpName(ParallelInterleaveDatasetOp::kDatasetType, params)));
}

TEST_F(ParallelInterleaveDatasetOpTest, DatasetOutputDtypes) {

@@ -461,9 +466,11 @@ TEST_F(ParallelInterleaveDatasetOpTest, IteratorOutputShapes) {
TEST_F(ParallelInterleaveDatasetOpTest, IteratorPrefix) {
  auto dataset_params = ParallelInterleaveDatasetParams1();
  TF_ASSERT_OK(Initialize(dataset_params));
+  name_utils::IteratorPrefixParams params;
+  params.op_version = dataset_params.op_version();
  TF_ASSERT_OK(CheckIteratorPrefix(
      name_utils::IteratorPrefix(ParallelInterleaveDatasetOp::kDatasetType,
-                                 dataset_params.iterator_prefix())));
+                                 dataset_params.iterator_prefix(), params)));
}

std::vector<IteratorSaveAndRestoreTestCase<ParallelInterleaveDatasetParams>>
@@ -561,6 +561,26 @@ REGISTER_OP("ParallelInterleaveDataset")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

+// This is the V2 of ParallelInterleaveDataset, renamed to differentiate it
+// from the non-experimental ParallelInterleaveDataset op.
+REGISTER_OP("LegacyParallelInterleaveDatasetV2")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Input("cycle_length: int64")
+    .Input("block_length: int64")
+    .Input("buffer_output_elements: int64")
+    .Input("prefetch_input_elements: int64")
+    .Output("handle: variant")
+    .Attr("f: func")
+    // "true", "false", or "default".
+    .Attr("deterministic: string = 'default'")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape);
+
+// This op is no longer used. We keep it so that we can read graphs written by
+// old versions of TensorFlow.
REGISTER_OP("ExperimentalParallelInterleaveDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
@@ -466,6 +466,7 @@ tf_py_test(
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/data/experimental/ops:interleave_ops",
+        "//tensorflow/python/data/experimental/ops:testing",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "@six_archive//:six",

@@ -26,6 +26,7 @@ import numpy as np
from six.moves import zip_longest

from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
@@ -729,6 +730,40 @@ class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
        results.append(elements)
      self.assertAllEqual(results[0], results[1])

+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              sloppy=[None, True, False], global_determinism=[True, False])))
+  def testDeterminismConfiguration(self, sloppy, global_determinism):
+    if sloppy is None:
+      expect_determinism = global_determinism
+    else:
+      expect_determinism = not sloppy
+    elements = list(range(1000))
+
+    def dataset_fn(delay_ms):
+
+      def interleave_fn(x):
+        ds = dataset_ops.Dataset.from_tensors(x)
+        if math_ops.equal(x, 0):
+          ds = ds.apply(testing.sleep(delay_ms * 1000))
+        else:
+          ds = ds.apply(testing.sleep(0))
+        return ds
+
+      dataset = dataset_ops.Dataset.from_tensor_slices(elements)
+      dataset = dataset.apply(
+          interleave_ops.parallel_interleave(
+              interleave_fn, cycle_length=10, sloppy=sloppy))
+
+      opts = dataset_ops.Options()
+      opts.experimental_deterministic = global_determinism
+      dataset = dataset.with_options(opts)
+      return dataset
+
+    self.checkDeterminism(dataset_fn, expect_determinism, elements)
+

if __name__ == "__main__":
  test.main()
@@ -76,9 +76,11 @@ def parallel_interleave(map_func,
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
-    sloppy: If false, elements are produced in deterministic order. Otherwise,
-      the implementation is allowed, for the sake of expediency, to produce
-      elements in a non-deterministic order.
+    sloppy: A boolean controlling whether determinism should be traded for
+      performance by allowing elements to be produced out of order. If
+      `sloppy` is `None`, the `tf.data.Options.experimental_deterministic`
+      dataset option (`True` by default) is used to decide whether to enforce a
+      deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation for
      each interleaved iterator).
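(Editorial note, not part of the diff: the new `sloppy` docstring above corresponds to the `deterministic` attr mapping implemented further down in `ParallelInterleaveDataset.__init__`; the helper below is hypothetical and only restates that mapping.)

def sloppy_to_deterministic_attr(sloppy):
  # sloppy=None defers to tf.data.Options.experimental_deterministic.
  if sloppy is None:
    return "default"
  return "false" if sloppy else "true"
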
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python import tf2
+from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes

@@ -248,8 +249,9 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
-    self._sloppy = ops.convert_to_tensor(
-        sloppy, dtype=dtypes.bool, name="sloppy")
+    if sloppy is not None:
+      self._sloppy = ops.convert_to_tensor(
+          sloppy, dtype=dtypes.bool, name="sloppy")
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
@@ -258,16 +260,34 @@ class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)
-    variant_tensor = ged_ops.parallel_interleave_dataset(
-        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-        self._map_func.function.captured_inputs,
-        self._cycle_length,
-        self._block_length,
-        self._sloppy,
-        self._buffer_output_elements,
-        self._prefetch_input_elements,
-        f=self._map_func.function,
-        **self._flat_structure)
+    if sloppy is None or compat.forward_compatible(2020, 3, 6):
+      if sloppy is None:
+        self._deterministic = "default"
+      elif sloppy:
+        self._deterministic = "false"
+      else:
+        self._deterministic = "true"
+      variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
+          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+          self._map_func.function.captured_inputs,
+          self._cycle_length,
+          self._block_length,
+          self._buffer_output_elements,
+          self._prefetch_input_elements,
+          f=self._map_func.function,
+          deterministic=self._deterministic,
+          **self._flat_structure)
+    else:
+      variant_tensor = ged_ops.parallel_interleave_dataset(
+          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+          self._map_func.function.captured_inputs,
+          self._cycle_length,
+          self._block_length,
+          self._sloppy,
+          self._buffer_output_elements,
+          self._prefetch_input_elements,
+          f=self._map_func.function,
+          **self._flat_structure)
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    variant_tensor)

@@ -1964,6 +1964,10 @@ tf_module {
    name: "LeftShift"
    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
+  member_method {
+    name: "LegacyParallelInterleaveDatasetV2"
+    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
+  }
  member_method {
    name: "Less"
    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

@@ -1964,6 +1964,10 @@ tf_module {
    name: "LeftShift"
    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
+  member_method {
+    name: "LegacyParallelInterleaveDatasetV2"
+    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
+  }
  member_method {
    name: "Less"
    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "