Enable op-level dataset determinism configuration.
Users can control determinism at a per-op level by specifying `deterministic` when calling `Dataset.interleave`. The `deterministic` argument takes precedence over the `experimental_deterministic` dataset option.

PiperOrigin-RevId: 293025557
Change-Id: I6f2efca6cf3dc6e9625256d5a602e1b3b5e7506e
This commit is contained in:
parent b11551a34d
commit b90657adc8
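For context, a minimal sketch of the surface this change adds (the file names below are hypothetical; the new piece is the `deterministic` argument to `Dataset.interleave`):

import tensorflow as tf

# Hypothetical input files; any dataset of filenames works.
filenames = ["/data/part-0.tfrecord", "/data/part-1.tfrecord"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Op-level determinism control added by this change: this interleave may
# return elements out of order to avoid head-of-line blocking, regardless of
# the dataset-level `experimental_deterministic` option.
dataset = dataset.interleave(
    tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    deterministic=False)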
@@ -0,0 +1,86 @@
op {
  graph_op_name: "ParallelInterleaveDatasetV3"
  visibility: HIDDEN
  in_arg {
    name: "input_dataset"
    description: <<END
Dataset that produces a stream of arguments for the function `f`.
END
  }
  in_arg {
    name: "other_arguments"
    description: <<END
Additional arguments to pass to `f` beyond those produced by `input_dataset`.
Evaluated once when the dataset is instantiated.
END
  }
  in_arg {
    name: "cycle_length"
    description: <<END
Number of datasets (each created by applying `f` to the elements of
`input_dataset`) among which the `ParallelInterleaveDatasetV3` will cycle in a
round-robin fashion.
END
  }
  in_arg {
    name: "block_length"
    description: <<END
Number of elements at a time to produce from each interleaved invocation of a
dataset returned by `f`.
END
  }
  in_arg {
    name: "num_parallel_calls"
    description: <<END
Determines the number of threads that should be used for fetching data from
input datasets in parallel. The Python API `tf.data.experimental.AUTOTUNE`
constant can be used to indicate that the level of parallelism should be
autotuned.
END
  }
  attr {
    name: "f"
    description: <<END
A function mapping elements of `input_dataset`, concatenated with
`other_arguments`, to a Dataset variant that contains elements matching
`output_types` and `output_shapes`.
END
  }
  attr {
    name: "deterministic"
    description: <<END
A string indicating the op-level determinism to use. `deterministic` controls
whether the interleave is allowed to return elements out of order if the next
element to be returned isn't available, but a later element is. Options are
"true", "false", and "default". "default" indicates that determinism should be
decided by the `experimental_deterministic` parameter of `tf.data.Options`.
END
  }
  attr {
    name: "Targuments"
    description: <<END
Types of the elements of `other_arguments`.
END
  }
  attr {
    name: "output_types"
  }
  attr {
    name: "output_shapes"
  }
  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
  description: <<END
The resulting dataset is similar to the `InterleaveDataset`, except that the
dataset will fetch records from the interleaved datasets in parallel.

The `tf.data` Python API creates instances of this op from
`Dataset.interleave()` when the `num_parallel_calls` parameter of that method
is set to any value other than `None`.

By default, the output of this dataset will be deterministic, which may result
in the dataset blocking if the next data item to be returned isn't available.
In order to avoid head-of-line blocking, one can either set the `deterministic`
attribute to "false", or leave it as "default" and set the
`experimental_deterministic` parameter of `tf.data.Options` to `False`.
This can improve performance at the expense of non-determinism.
END
}
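To make the "default" policy and the precedence rule described above concrete, here is a small usage sketch (illustrative only, not part of the diff):

import tensorflow as tf

options = tf.data.Options()
options.experimental_deterministic = False

# `deterministic` left unset maps to the "default" policy, so the option
# above decides: this interleave may return elements out of order.
ds_default = tf.data.Dataset.range(10).interleave(
    lambda x: tf.data.Dataset.from_tensors(x),
    cycle_length=2, num_parallel_calls=2).with_options(options)

# An explicit `deterministic=True` takes precedence over the option, so this
# interleave stays deterministic even though the option requests otherwise.
ds_forced = tf.data.Dataset.range(10).interleave(
    lambda x: tf.data.Dataset.from_tensors(x),
    cycle_length=2, num_parallel_calls=2,
    deterministic=True).with_options(options)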
@@ -89,12 +89,13 @@ constexpr std::array<const char*, 28> kPassThroughOps = {
};

// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 5> kFuncDatasetOps = {
constexpr std::array<const char*, 6> kFuncDatasetOps = {
    "ExperimentalParallelInterleaveDataset",
    "FlatMapDataset",
    "InterleaveDataset",
    "ParallelInterleaveDataset",
    "ParallelInterleaveDatasetV2"
    "ParallelInterleaveDatasetV2",
    "ParallelInterleaveDatasetV3"
};

constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
@@ -33,10 +33,11 @@ namespace {
constexpr char kLegacyAutotune[] = "legacy_autotune";
constexpr char kPrefetchDataset[] = "PrefetchDataset";

constexpr std::array<const char*, 4> kAsyncDatasetOps = {
constexpr std::array<const char*, 5> kAsyncDatasetOps = {
    "ExperimentalMapAndBatchDataset",
    "ParallelMapDataset",
    "ParallelInterleaveDatasetV2",
    "ParallelInterleaveDatasetV3",
    "MapAndBatchDataset",
};

@@ -25,6 +25,18 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

namespace {
constexpr std::array<const char*, 3> kSloppyAttrOps = {
    "ParallelInterleaveDatasetV2",
    "ParallelMapDataset",
    "ParseExampleDataset",
};

constexpr std::array<const char*, 1> kDeterministicAttrOps = {
    "ParallelInterleaveDatasetV3",
};
}  // anonymous namespace

Status MakeSloppy::OptimizeAndCollectStats(Cluster* cluster,
                                           const GrapplerItem& item,
                                           GraphDef* output,
@@ -33,11 +45,20 @@ Status MakeSloppy::OptimizeAndCollectStats(Cluster* cluster,
  MutableGraphView graph(output);

  for (NodeDef& node : *output->mutable_node()) {
    if (node.op() == "ParallelInterleaveDatasetV2" ||
        node.op() == "ParallelMapDataset" ||
        node.op() == "ParseExampleDataset") {
    for (const auto& op_name : kSloppyAttrOps) {
      if (node.op() == op_name) {
        (*node.mutable_attr())["sloppy"].set_b(true);
        stats->num_changes++;
        break;
      }
    }
    for (const auto& op_name : kDeterministicAttrOps) {
      if (node.op() == op_name &&
          node.attr().at("deterministic").s() == "default") {
        (*node.mutable_attr())["deterministic"].set_s("false");
        stats->num_changes++;
        break;
      }
    }
  }
  return Status::OK();
@@ -760,5 +760,43 @@ std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
      std::move(runner), std::placeholders::_1);
}

Status DeterminismPolicy::FromString(const std::string& s,
                                     DeterminismPolicy* out) {
  DeterminismPolicy::Type type;
  if (s == DeterminismPolicy::kDeterministic) {
    type = DeterminismPolicy::Type::kDeterministic;
  } else if (s == DeterminismPolicy::kNondeterministic) {
    type = DeterminismPolicy::Type::kNondeterministic;
  } else if (s == DeterminismPolicy::kDefault) {
    type = DeterminismPolicy::Type::kDefault;
  } else {
    return errors::InvalidArgument("Unrecognized determinism policy: ", s);
  }
  *out = DeterminismPolicy(type);
  return Status::OK();
}

DeterminismPolicy::DeterminismPolicy(bool is_deterministic) {
  if (is_deterministic) {
    determinism_ = DeterminismPolicy::Type::kDeterministic;
  } else {
    determinism_ = DeterminismPolicy::Type::kNondeterministic;
  }
}

std::string DeterminismPolicy::String() const {
  switch (determinism_) {
    case DeterminismPolicy::Type::kDeterministic:
      return DeterminismPolicy::kDeterministic;
    case DeterminismPolicy::Type::kNondeterministic:
      return DeterminismPolicy::kNondeterministic;
    case DeterminismPolicy::Type::kDefault:
      return DeterminismPolicy::kDefault;
    default:
      LOG(ERROR) << "Unrecognized determinism value";
      return "Unrecognized";
  }
}

}  // namespace data
}  // namespace tensorflow
@@ -148,6 +148,46 @@ Status HashTensor(const Tensor& tensor, uint64* hash);
// the same between TensorFlow builds.
Status HashGraph(const GraphDef& graph, uint64* hash);

// Dataset op level determinism policy.
class DeterminismPolicy {
 public:
  enum class Type : int {
    // The op must produce elements deterministically.
    kDeterministic,
    // The op may relax determinism to improve performance.
    kNondeterministic,
    // The determinism policy is not specified at the op level. In this case we
    // use the experimental_deterministic dataset option to determine the
    // determinism policy.
    kDefault,
  };
  static constexpr const char* const kDeterministic = "true";
  static constexpr const char* const kNondeterministic = "false";
  static constexpr const char* const kDefault = "default";

  DeterminismPolicy() : determinism_(Type::kDefault) {}
  explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
  // Creates a DeterminismPolicy with Type kDeterministic or
  // kNondeterministic, depending on the value of `is_deterministic`.
  explicit DeterminismPolicy(bool is_deterministic);

  static Status FromString(const std::string& s, DeterminismPolicy* out);

  // Returns the string representing the determinism policy. This will be one of
  // the string constants defined above.
  std::string String() const;

  /// Convenience methods for checking the DeterminismPolicy::Type.
  bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
  bool IsNondeterministic() const {
    return determinism_ == Type::kNondeterministic;
  }
  bool IsDefault() const { return determinism_ == Type::kDefault; }

 private:
  Type determinism_;
};

// Helper class for reading data from a vector of VariantTensorData objects.
class VariantTensorDataReader : public IteratorStateReader {
 public:
@@ -241,6 +241,34 @@ TEST(DatasetUtilsTest, RunnerWithMaxParallelism) {
  runner(fn);
}

TEST(DatasetUtilsTest, ParseDeterminismPolicy) {
  DeterminismPolicy determinism;
  TF_ASSERT_OK(DeterminismPolicy::FromString("true", &determinism));
  EXPECT_TRUE(determinism.IsDeterministic());
  TF_ASSERT_OK(DeterminismPolicy::FromString("false", &determinism));
  EXPECT_TRUE(determinism.IsNondeterministic());
  TF_ASSERT_OK(DeterminismPolicy::FromString("default", &determinism));
  EXPECT_TRUE(determinism.IsDefault());
}

TEST(DatasetUtilsTest, DeterminismString) {
  for (auto s : {"true", "false", "default"}) {
    DeterminismPolicy determinism;
    TF_ASSERT_OK(DeterminismPolicy::FromString(s, &determinism));
    EXPECT_TRUE(s == determinism.String());
  }
}

TEST(DatasetUtilsTest, BoolConstructor) {
  EXPECT_TRUE(DeterminismPolicy(true).IsDeterministic());
  EXPECT_FALSE(DeterminismPolicy(true).IsNondeterministic());
  EXPECT_FALSE(DeterminismPolicy(true).IsDefault());

  EXPECT_TRUE(DeterminismPolicy(false).IsNondeterministic());
  EXPECT_FALSE(DeterminismPolicy(false).IsDeterministic());
  EXPECT_FALSE(DeterminismPolicy(false).IsDefault());
}

TEST_F(DatasetHashUtilsTest, HashFunctionSameFunctionDifferentNames) {
  FunctionDefLibrary fl;
@@ -67,6 +67,8 @@ namespace data {
    ParallelInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;

constexpr char kTfDataParallelInterleaveWorkerPool[] =
@@ -115,18 +117,19 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
          int64 block_length, int64 num_parallel_calls, bool sloppy,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
          int64 block_length, int64 num_parallel_calls,
          DeterminismPolicy deterministic, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        num_parallel_calls_(num_parallel_calls),
        sloppy_(sloppy),
        deterministic_(deterministic),
        output_types_(output_types),
        output_shapes_(output_shapes) {
        output_shapes_(output_shapes),
        op_version_(op_version) {
    input_->Ref();
  }

@@ -136,12 +139,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    bool deterministic =
        deterministic_.IsDeterministic() || deterministic_.IsDefault();
    return absl::make_unique<ParallelInterleaveIterator>(
        ParallelInterleaveIterator::Params{
            this,
            name_utils::IteratorPrefix(
                ParallelInterleaveDatasetOp::kDatasetType, prefix, params)},
        sloppy_);
        deterministic);
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }
@@ -179,30 +184,40 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));

    std::vector<std::pair<StringPiece, AttrValue>> attrs;
    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    attrs.emplace_back(kFunc, f);

    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    if (op_version_ == 2) {
      AttrValue sloppy_attr;
      b->BuildAttrValue(sloppy_, &sloppy_attr);
      b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
      attrs.emplace_back(kSloppy, sloppy_attr);
    }
    if (op_version_ == 3) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    TF_RETURN_IF_ERROR(b->AddDataset(this,
                                     {{0, input_node},
                                      {2, cycle_length_node},
                                      {3, block_length_node},
                                      {4, num_parallel_calls_node}},
                                     {{1, other_arguments}},
                                     {{kFunc, f},
                                      {kTarguments, other_arguments_types_attr},
                                      {kSloppy, sloppy_attr}},
                                     output));
                                     {{1, other_arguments}}, attrs, output));
    return Status::OK();
  }

 private:
  class ParallelInterleaveIterator : public DatasetIterator<Dataset> {
   public:
    ParallelInterleaveIterator(const Params& params, bool sloppy)
    ParallelInterleaveIterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
          per_iterator_prefetch_(
              static_cast<int>(params.dataset->block_length_ *
@@ -215,7 +230,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
          num_parallel_calls_(std::make_shared<model::SharedState>(
              params.dataset->num_parallel_calls_, mu_,
              num_parallel_calls_cond_var_)),
          sloppy_(sloppy),
          deterministic_(deterministic),
          current_elements_(params.dataset->cycle_length_) {}

    ~ParallelInterleaveIterator() override {
@@ -236,7 +251,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
          ",cycle_length=", dataset()->cycle_length_,
          ",block_length=", dataset()->block_length_,
          ",autotune=", dataset()->num_parallel_calls_ == model::kAutotune,
          ",deterministic=", !sloppy_, "#");
          ",deterministic=", deterministic_, "#");
    }

    Status Initialize(IteratorContext* ctx) override {
@@ -279,12 +294,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
        EnsureThreadsStarted();
        while (!cancelled_ && !Consume(&result)) {
          RecordStop(ctx);
          if (sloppy_) {
            sloppy_cond_var_.wait(l);
          } else {
          if (deterministic_) {
            VLOG(3) << "Blocked waiting for element "
                    << current_elements_[cycle_index_]->id;
            current_elements_[cycle_index_]->cond_var.wait(l);
          } else {
            any_element_available_cond_var_.wait(l);
          }
          RecordStart(ctx);
        }
@@ -470,7 +485,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
      while (wait && outstanding_threads_ > 0) {
        outstanding_threads_finished_cond_var_.wait(l);
      }
      sloppy_cond_var_.notify_all();
      any_element_available_cond_var_.notify_all();
      zero_active_workers_cond_var_.notify_all();
    }

@@ -522,11 +537,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
    // points to a valid result or is null if end of input has been reached.
    bool Consume(std::shared_ptr<Result>* result)
        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!sloppy_) {
      if (deterministic_) {
        return ConsumeHelper(result);
      }
      // If we are allowed to be sloppy (i.e. return results out of order),
      // try to find an element in the cycle that has a result available.
      // If we are allowed to be nondeterministic (i.e. return results out of
      // order), try to find an element in the cycle that has a result
      // available.
      for (int i = 0; i < dataset()->cycle_length_; ++i) {
        if (ConsumeHelper(result)) {
          return true;
@@ -916,10 +932,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {

    void NotifyElementUpdate(std::shared_ptr<Element> element)
        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (sloppy_) {
        sloppy_cond_var_.notify_one();
      } else {
      if (deterministic_) {
        element->cond_var.notify_one();
      } else {
        any_element_available_cond_var_.notify_one();
      }
    }

@@ -1337,11 +1353,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
    int num_current_workers_ GUARDED_BY(mu_) = 0;

    // Condition variable to signal that a result has been produced by some
    // element thread. Only used when `sloppy_` is true.
    condition_variable sloppy_cond_var_;
    // element thread. Only used when `deterministic` is false.
    condition_variable any_element_available_cond_var_;

    // Determines whether outputs can be produced in non-deterministic order.
    const bool sloppy_;
    // Determines whether outputs can be produced in deterministic order.
    const bool deterministic_;

    // Iterator for input elements.
    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
@@ -1400,22 +1416,37 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
  const int64 cycle_length_;
  const int64 block_length_;
  const int64 num_parallel_calls_;
  const int op_version_ = 2;
  const bool sloppy_;
  const DeterminismPolicy deterministic_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const int op_version_;
};

ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
    : UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 2 : 3) {
  FunctionMetadata::Params params;
  params.is_multi_device_function = true;
  OP_REQUIRES_OK(ctx,
                 FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy_));
  if (op_version_ == 2) {
    bool sloppy;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
    }
  }
  if (op_version_ == 3) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
}

void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
@@ -1455,14 +1486,17 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
  }

  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
                        block_length, num_parallel_calls, sloppy_,
                        output_types_, output_shapes_);
                        block_length, num_parallel_calls, deterministic_,
                        output_types_, output_shapes_, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV3").Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV2");
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV3");
}  // namespace
}  // namespace data
}  // namespace tensorflow
@@ -17,6 +17,7 @@ limitations under the License.

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"

namespace tensorflow {
namespace data {
@@ -33,6 +34,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
  static constexpr const char* const kTarguments = "Targuments";
  static constexpr const char* const kOutputTypes = "output_types";
  static constexpr const char* const kOutputShapes = "output_shapes";
  static constexpr const char* const kDeterministic = "deterministic";
  static constexpr const char* const kSloppy = "sloppy";

  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx);
@@ -43,10 +45,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {

 private:
  class Dataset;
  const int op_version_;
  std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  bool sloppy_;
  DeterminismPolicy deterministic_;
};

}  // namespace data
@@ -18,21 +18,19 @@ namespace data {
namespace {

constexpr char kNodeName[] = "parallel_interleave_dataset";
constexpr int kOpVersion = 2;
constexpr int kOpVersion = 3;

class ParallelInterleaveDatasetParams : public DatasetParams {
 public:
  template <typename T>
  ParallelInterleaveDatasetParams(T input_dataset_params,
                                  std::vector<Tensor> other_arguments,
                                  int64 cycle_length, int64 block_length,
                                  int64 num_parallel_calls,
  ParallelInterleaveDatasetParams(
      T input_dataset_params, std::vector<Tensor> other_arguments,
      int64 cycle_length, int64 block_length, int64 num_parallel_calls,
      FunctionDefHelper::AttrValueWrapper func,
      std::vector<FunctionDef> func_lib,
      DataTypeVector type_arguments,
      DataTypeVector output_dtypes,
      std::vector<PartialTensorShape> output_shapes,
      bool sloppy, string node_name)
      std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
      const DataTypeVector& output_dtypes,
      const std::vector<PartialTensorShape>& output_shapes,
      const std::string& deterministic, const std::string& node_name)
      : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
                      std::move(node_name)),
        other_arguments_(std::move(other_arguments)),
@@ -42,7 +40,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
        func_(std::move(func)),
        func_lib_(std::move(func_lib)),
        type_arguments_(std::move(type_arguments)),
        sloppy_(sloppy) {
        deterministic_(deterministic) {
    input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
    op_version_ = kOpVersion;
    name_utils::IteratorPrefixParams params;
@@ -78,10 +76,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
  Status GetAttributes(AttributeVector* attr_vector) const override {
    *attr_vector = {
        {ParallelInterleaveDatasetOp::kFunc, func_},
        {ParallelInterleaveDatasetOp::kDeterministic, deterministic_},
        {ParallelInterleaveDatasetOp::kTarguments, type_arguments_},
        {ParallelInterleaveDatasetOp::kOutputShapes, output_shapes_},
        {ParallelInterleaveDatasetOp::kOutputTypes, output_dtypes_},
        {ParallelInterleaveDatasetOp::kSloppy, sloppy_}};
        {ParallelInterleaveDatasetOp::kOutputTypes, output_dtypes_}};
    return Status::OK();
  }

@@ -99,7 +97,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
  FunctionDefHelper::AttrValueWrapper func_;
  std::vector<FunctionDef> func_lib_;
  DataTypeVector type_arguments_;
  bool sloppy_;
  std::string deterministic_;
};

class ParallelInterleaveDatasetOpTest : public DatasetOpsTestBase {};
@@ -134,7 +132,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/false,
      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*node_name=*/kNodeName);
}

@@ -159,7 +157,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/false,
      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*node_name=*/kNodeName);
}

@@ -184,7 +182,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -210,7 +208,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -235,7 +233,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_STRING},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -260,7 +258,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_STRING},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -285,7 +283,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_STRING},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/false,
      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*node_name=*/kNodeName);
}

@@ -310,7 +308,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_STRING},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -335,7 +333,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_STRING},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -360,7 +358,7 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_STRING},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -383,7 +381,7 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_STRING},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/false,
      /*deterministic=*/DeterminismPolicy::kDeterministic,
      /*node_name=*/kNodeName);
}

@@ -409,7 +407,7 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -435,7 +433,7 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}

@@ -461,7 +459,7 @@ ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
      /*type_arguments=*/{},
      /*output_dtypes=*/{DT_INT64},
      /*output_shapes=*/{PartialTensorShape({1})},
      /*sloppy=*/true,
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
      /*node_name=*/kNodeName);
}
@@ -212,6 +212,21 @@ REGISTER_OP("ParallelInterleaveDatasetV2")
    .Attr("sloppy: bool = false")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("ParallelInterleaveDatasetV3")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Input("cycle_length: int64")
    .Input("block_length: int64")
    .Input("num_parallel_calls: int64")
    .Output("handle: variant")
    .Attr("f: func")
    // "true", "false", or "default".
    .Attr("deterministic: string = 'default'")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("FilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
@@ -19,6 +19,7 @@ from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -58,11 +59,13 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
        dataset, [list(range(i + 1, i + 11)) for i in range(0, 50, 10)])

  @combinations.generate(test_base.default_test_combinations())
  def testParallelInterleaveV2(self):
  def testParallelInterleave(self):
    dataset = dataset_ops.Dataset.range(100)
    parallel_interleave = "ParallelInterleaveV2"
    if compat.forward_compatible(2020, 2, 20):
      parallel_interleave = "ParallelInterleaveV3"
    dataset = dataset.apply(
        testing.assert_next(
            ["ParallelInterleaveV2", "Prefetch", "FiniteTake"]))
        testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
    dataset = dataset.interleave(
        lambda x: dataset_ops.Dataset.from_tensors(x + 1),
        num_parallel_calls=dataset_ops.AUTOTUNE)
@@ -73,9 +76,12 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
  @combinations.generate(test_base.default_test_combinations())
  def testChainedParallelDatasets(self):
    dataset = dataset_ops.Dataset.range(100)
    parallel_interleave = "ParallelInterleaveV2"
    if compat.forward_compatible(2020, 2, 20):
      parallel_interleave = "ParallelInterleaveV3"
    dataset = dataset.apply(
        testing.assert_next([
            "ParallelMap", "Prefetch", "ParallelInterleaveV2", "Prefetch",
            "ParallelMap", "Prefetch", parallel_interleave, "Prefetch",
            "MapAndBatch", "Prefetch", "FiniteTake"
        ]))
    dataset = dataset.map(
@@ -273,6 +273,7 @@ tf_py_test(
        "//tensorflow/python:script_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/data/experimental/ops:testing",
        "//tensorflow/python/data/ops:dataset_ops",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
@@ -23,12 +23,14 @@ import os
from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test

@@ -309,6 +311,55 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
        interleave_fn, cycle_length=2, num_parallel_calls=2)
    self.assertDatasetProduces(dataset, list(range(5)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              local_determinism=[None, True, False],
              global_determinism=[True, False])))
  def testDeterminismConfiguration(self, local_determinism, global_determinism):

    def make_interleave_fn(delay_ms):

      def interleave_fn(x):
        ds = dataset_ops.Dataset.from_tensors(x)
        if math_ops.equal(x, 0):
          ds = ds.apply(testing.sleep(delay_ms * 1000))
        else:
          ds = ds.apply(testing.sleep(0))
        return ds

      return interleave_fn

    expect_determinism = local_determinism or (local_determinism is None and
                                               global_determinism)
    if expect_determinism:
      delays_ms = [100]
    else:
      delays_ms = [10, 100, 1000, 20000]
    # We consider the test a success if it succeeds under any delay_ms. The
    # delay_ms needed to observe non-deterministic ordering varies across
    # test machines. Usually 10 or 100 milliseconds is enough, but on slow
    # machines it could take longer.
    for delay_ms in delays_ms:
      dataset = dataset_ops.Dataset.range(2)

      dataset = dataset.interleave(
          make_interleave_fn(delay_ms),
          cycle_length=2,
          num_parallel_calls=2,
          deterministic=local_determinism)

      opts = dataset_ops.Options()
      opts.experimental_deterministic = global_determinism
      dataset = dataset.with_options(opts)

      expected = [0, 1] if expect_determinism else [1, 0]
      actual = self.getDatasetOutput(dataset)
      if actual == expected:
        return
    self.assertEqual(expected, actual)


if __name__ == "__main__":
  test.main()
@@ -30,6 +30,7 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
@@ -1621,7 +1622,8 @@ name=None))
                 map_func,
                 cycle_length=AUTOTUNE,
                 block_length=1,
                 num_parallel_calls=None):
                 num_parallel_calls=None,
                 deterministic=None):
    """Maps `map_func` across this dataset, and interleaves the results.

    For example, you can use `Dataset.interleave()` to process many input files
@@ -1669,9 +1671,21 @@ name=None))
     5, 5]

    NOTE: The order of elements yielded by this transformation is
    deterministic, as long as `map_func` is a pure function. If
    `map_func` contains any stateful operations, the order in which
    that state is accessed is undefined.
    deterministic, as long as `map_func` is a pure function and
    `deterministic=True`. If `map_func` contains any stateful operations, the
    order in which that state is accessed is undefined.

    Performance can often be improved by setting `num_parallel_calls` so that
    `interleave` will use multiple threads to fetch elements. If determinism
    isn't required, it can also improve performance to set
    `deterministic=False`.

    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
    ...     "/var/data/file3.txt", "/var/data/file4.txt"]
    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
    >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x),
    ...     cycle_length=4, num_parallel_calls=tf.data.experimental.AUTOTUNE,
    ...     deterministic=False)

    Args:
      map_func: A function mapping a dataset element to a dataset.
@@ -1688,6 +1702,12 @@ name=None))
        from cycle elements synchronously with no parallelism. If the value
        `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
        calls is set dynamically based on available CPU.
      deterministic: (Optional.) A boolean controlling whether determinism
        should be traded for performance by allowing elements to be produced
        out of order. If `deterministic` is `None`, the
        `tf.data.Options.experimental_deterministic` dataset option (`True` by
        default) is used to decide whether to produce elements
        deterministically.

    Returns:
      Dataset: A `Dataset`.
@@ -1696,7 +1716,8 @@ name=None))
      return InterleaveDataset(self, map_func, cycle_length, block_length)
    else:
      return ParallelInterleaveDataset(self, map_func, cycle_length,
                                       block_length, num_parallel_calls)
                                       block_length, num_parallel_calls,
                                       deterministic)

  def filter(self, predicate):
    """Filters this dataset according to `predicate`.
@@ -2334,9 +2355,11 @@ class DatasetV1(DatasetV2):
                 map_func,
                 cycle_length=AUTOTUNE,
                 block_length=1,
                 num_parallel_calls=None):
    return DatasetV1Adapter(super(DatasetV1, self).interleave(
        map_func, cycle_length, block_length, num_parallel_calls))
                 num_parallel_calls=None,
                 deterministic=None):
    return DatasetV1Adapter(
        super(DatasetV1, self).interleave(map_func, cycle_length, block_length,
                                          num_parallel_calls, deterministic))

  @functools.wraps(DatasetV2.filter)
  def filter(self, predicate):
@@ -4016,8 +4039,13 @@ class InterleaveDataset(UnaryDataset):
class ParallelInterleaveDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and interleaves the result."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               num_parallel_calls):
  def __init__(self,
               input_dataset,
               map_func,
               cycle_length,
               block_length,
               num_parallel_calls,
               deterministic=None):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
@@ -4033,6 +4061,24 @@ class ParallelInterleaveDataset(UnaryDataset):
        block_length, dtype=dtypes.int64, name="block_length")
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    if deterministic is None:
      deterministic_string = "default"
    elif deterministic:
      deterministic_string = "true"
    else:
      deterministic_string = "false"

    if deterministic is not None or compat.forward_compatible(2020, 2, 20):
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
          self._cycle_length,
          self._block_length,
          self._num_parallel_calls,
          f=self._map_func.function,
          deterministic=deterministic_string,
          **self._flat_structure)
    else:
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
@@ -79,7 +79,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -81,7 +81,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -81,7 +81,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -81,7 +81,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -81,7 +81,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -81,7 +81,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -81,7 +81,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -2632,6 +2632,10 @@ tf_module {
    name: "ParallelInterleaveDatasetV2"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'sloppy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "ParallelInterleaveDatasetV3"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
  }
  member_method {
    name: "ParallelMapDataset"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "

@@ -58,7 +58,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -59,7 +59,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -60,7 +60,7 @@ tf_class {
  }
  member_method {
    name: "interleave"
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\'], "
    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'-1\', \'1\', \'None\', \'None\'], "
  }
  member_method {
    name: "list_files"

@@ -2632,6 +2632,10 @@ tf_module {
    name: "ParallelInterleaveDatasetV2"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'sloppy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "ParallelInterleaveDatasetV3"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
  }
  member_method {
    name: "ParallelMapDataset"
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "