Enable op-level dataset determinism configuration for ParallelMap
Users can control determinism at a per-op level by specifying `deterministic` when calling map(). The `deterministic` argument takes higher priority than the `experimental_deterministic` dataset option. PiperOrigin-RevId: 294786773 Change-Id: If89f87dbe2adb51aad79791aa3f18072132e74c6
This commit is contained in:
parent
fc36231b87
commit
47940211fd
@ -0,0 +1,16 @@
|
||||
op {
|
||||
graph_op_name: "ParallelMapDatasetV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "num_parallel_calls"
|
||||
description: <<END
|
||||
The number of concurrent invocations of `f` that process
|
||||
elements from `input_dataset` in parallel.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
|
||||
description: <<END
|
||||
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
|
||||
to `num_parallel_calls` copies of `f` in parallel.
|
||||
END
|
||||
}
|
@ -57,7 +57,7 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
|
||||
"ZipDataset"
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 28> kPassThroughOps = {
|
||||
constexpr std::array<const char*, 29> kPassThroughOps = {
|
||||
"_Retval",
|
||||
"AssertNextDataset",
|
||||
"BatchDataset",
|
||||
@ -75,6 +75,7 @@ constexpr std::array<const char*, 28> kPassThroughOps = {
|
||||
"ModelDataset",
|
||||
"OptimizeDataset",
|
||||
"ParallelMapDataset",
|
||||
"ParallelMapDatasetV2",
|
||||
"PrefetchDataset",
|
||||
"ReduceDataset",
|
||||
"RebatchDataset",
|
||||
|
@ -33,10 +33,11 @@ namespace {
|
||||
constexpr char kLegacyAutotune[] = "legacy_autotune";
|
||||
constexpr char kPrefetchDataset[] = "PrefetchDataset";
|
||||
|
||||
constexpr std::array<const char*, 6> kAsyncDatasetOps = {
|
||||
"ExperimentalMapAndBatchDataset", "ParallelMapDataset",
|
||||
constexpr std::array<const char*, 7> kAsyncDatasetOps = {
|
||||
"ExperimentalMapAndBatchDataset", "MapAndBatchDataset",
|
||||
"ParallelInterleaveDatasetV2", "ParallelInterleaveDatasetV3",
|
||||
"ParallelInterleaveDatasetV4", "MapAndBatchDataset",
|
||||
"ParallelInterleaveDatasetV4", "ParallelMapDataset",
|
||||
"ParallelMapDatasetV2",
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
@ -32,9 +32,10 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
|
||||
"ParseExampleDataset",
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 2> kDeterministicAttrOps = {
|
||||
constexpr std::array<const char*, 3> kDeterministicAttrOps = {
|
||||
"ParallelInterleaveDatasetV3",
|
||||
"ParallelInterleaveDatasetV4",
|
||||
"ParallelMapDatasetV2",
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
|
@ -32,6 +32,12 @@ namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kFusedOpName[] = "MapAndBatchDataset";
|
||||
constexpr char kParallelMap[] = "ParallelMapDataset";
|
||||
constexpr char kParallelMapV2[] = "ParallelMapDatasetV2";
|
||||
|
||||
bool IsParallelMap(const NodeDef& node) {
|
||||
return node.op() == kParallelMap || node.op() == kParallelMapV2;
|
||||
}
|
||||
|
||||
NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
|
||||
MutableGraphView* graph) {
|
||||
@ -44,7 +50,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
|
||||
|
||||
// Set the `other_arguments` input arguments.
|
||||
int num_other_args;
|
||||
if (map_node.op() == "ParallelMapDataset") {
|
||||
if (IsParallelMap(map_node)) {
|
||||
num_other_args = map_node.input_size() - 2;
|
||||
} else {
|
||||
num_other_args = map_node.input_size() - 1;
|
||||
@ -57,7 +63,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
|
||||
new_node.add_input(batch_node.input(1));
|
||||
|
||||
// Set the `num_parallel_calls` input argument.
|
||||
if (map_node.op() == "ParallelMapDataset") {
|
||||
if (map_node.op() == kParallelMap) {
|
||||
// The type of the `num_parallel_calls` argument in ParallelMapDataset
|
||||
// and MapAndBatchDataset is different (int32 and int64 respectively)
|
||||
// so we cannot reuse the same Const node and thus create a new one.
|
||||
@ -115,7 +121,7 @@ Status MapAndBatchFusion::OptimizeAndCollectStats(Cluster* cluster,
|
||||
const NodeDef& batch_node = node;
|
||||
NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
|
||||
|
||||
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
|
||||
if (node2->op() != "MapDataset" && !IsParallelMap(*node2)) {
|
||||
continue;
|
||||
}
|
||||
// Use a more descriptive variable name now that we know the node type.
|
||||
|
@ -52,6 +52,7 @@ constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
|
||||
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
|
||||
constexpr char kMapOp[] = "MapDataset";
|
||||
constexpr char kParallelMapOp[] = "ParallelMapDataset";
|
||||
constexpr char kParallelMapV2Op[] = "ParallelMapDatasetV2";
|
||||
constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
|
||||
constexpr char kPrefetchOp[] = "PrefetchDataset";
|
||||
|
||||
@ -253,7 +254,13 @@ Status AddNewMapNode(const NodeDef& old_map_node, const NodeDef& old_batch_node,
|
||||
const FunctionDef& vectorized_func,
|
||||
MutableGraphView* graph, NodeDef** new_map_node) {
|
||||
NodeDef map_node;
|
||||
map_node.set_op(old_map_node.op() == kMapOp ? kMapOp : kParallelMapOp);
|
||||
if (old_map_node.op() == kMapOp) {
|
||||
map_node.set_op(kMapOp);
|
||||
} else if (old_map_node.op() == kParallelMapOp) {
|
||||
map_node.set_op(kParallelMapOp);
|
||||
} else {
|
||||
map_node.set_op(kParallelMapV2Op);
|
||||
}
|
||||
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->graph(), &map_node);
|
||||
|
||||
// Set the `input_dataset` input argument
|
||||
@ -267,7 +274,7 @@ Status AddNewMapNode(const NodeDef& old_map_node, const NodeDef& old_batch_node,
|
||||
CopyInputs("other_arguments", input_map, old_map_node, &map_node));
|
||||
|
||||
// Set the `num_parallel_calls` input argument
|
||||
if (old_map_node.op() != kMapOp) {
|
||||
if (map_node.op() == kParallelMapOp) {
|
||||
// `num_parallel_calls` = `kAutotune`
|
||||
// TODO(rachelim): Evaluate the performance of other potential
|
||||
// transformations to `num_parallel_calls`,
|
||||
@ -275,6 +282,10 @@ Status AddNewMapNode(const NodeDef& old_map_node, const NodeDef& old_batch_node,
|
||||
auto autotune_val = graph_utils::AddScalarConstNode(
|
||||
static_cast<int32>(data::model::kAutotune), graph);
|
||||
map_node.add_input(autotune_val->name());
|
||||
} else if (map_node.op() == kParallelMapV2Op) {
|
||||
auto autotune_val =
|
||||
graph_utils::AddScalarConstNode(data::model::kAutotune, graph);
|
||||
map_node.add_input(autotune_val->name());
|
||||
}
|
||||
|
||||
// Set attrs
|
||||
@ -287,6 +298,12 @@ Status AddNewMapNode(const NodeDef& old_map_node, const NodeDef& old_batch_node,
|
||||
}
|
||||
|
||||
(*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
|
||||
if (old_map_node.attr().contains("sloppy")) {
|
||||
graph_utils::CopyAttribute("sloppy", old_map_node, &map_node);
|
||||
}
|
||||
if (old_map_node.attr().contains("deterministic")) {
|
||||
graph_utils::CopyAttribute("deterministic", old_map_node, &map_node);
|
||||
}
|
||||
*new_map_node = graph->AddNode(std::move(map_node));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -468,7 +485,8 @@ bool FindMapAndBatchPattern(const MutableGraphView& graph, const NodeDef& node,
|
||||
tmp_input_node = graph_utils::GetInputNode(*tmp_input_node, graph);
|
||||
}
|
||||
if (tmp_input_node->op() != kMapOp &&
|
||||
tmp_input_node->op() != kParallelMapOp) {
|
||||
tmp_input_node->op() != kParallelMapOp &&
|
||||
tmp_input_node->op() != kParallelMapV2Op) {
|
||||
return false;
|
||||
}
|
||||
map_node = tmp_input_node;
|
||||
|
@ -38,7 +38,7 @@ constexpr char kBatchOp[] = "BatchDataset";
|
||||
constexpr char kBatchV2Op[] = "BatchDatasetV2";
|
||||
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
|
||||
constexpr char kMapOp[] = "MapDataset";
|
||||
constexpr char kParallelMapOp[] = "ParallelMapDataset";
|
||||
constexpr char kParallelMapOp[] = "ParallelMapDatasetV2";
|
||||
constexpr char kChooseFastestOp[] = "ChooseFastestBranchDataset";
|
||||
constexpr char kPrefetchOp[] = "PrefetchDataset";
|
||||
constexpr char kAttrNameF[] = "f";
|
||||
|
@ -483,6 +483,7 @@ tf_cc_test(
|
||||
":dataset_test_base",
|
||||
":dataset_utils",
|
||||
":iterator_ops",
|
||||
":name_utils",
|
||||
":parallel_map_dataset_op",
|
||||
":range_dataset_op",
|
||||
":stats_utils",
|
||||
|
@ -215,7 +215,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
absl::make_unique<ParseExampleFunctor>(this);
|
||||
return NewParallelMapIterator(
|
||||
{this, strings::StrCat(prefix, "::ParseExample")}, input_,
|
||||
std::move(parse_example_functor), num_parallel_calls_, sloppy_,
|
||||
std::move(parse_example_functor), num_parallel_calls_, !sloppy_,
|
||||
/*preserve_cardinality=*/true);
|
||||
}
|
||||
|
||||
|
@ -48,6 +48,7 @@ namespace data {
|
||||
/* static */ constexpr const char* const ParallelMapDatasetOp::kOutputShapes;
|
||||
/* static */ constexpr const char* const
|
||||
ParallelMapDatasetOp::kUseInterOpParallelism;
|
||||
/* static */ constexpr const char* const ParallelMapDatasetOp::kDeterministic;
|
||||
/* static */ constexpr const char* const ParallelMapDatasetOp::kSloppy;
|
||||
/* static */ constexpr const char* const
|
||||
ParallelMapDatasetOp::kPreserveCardinality;
|
||||
@ -58,18 +59,20 @@ constexpr int kStatsReportingPeriodMillis = 1000;
|
||||
class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
int32 num_parallel_calls, const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes, bool sloppy,
|
||||
int64 num_parallel_calls, const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
DeterminismPolicy deterministic,
|
||||
std::unique_ptr<CapturedFunction> captured_func,
|
||||
bool preserve_cardinality)
|
||||
bool preserve_cardinality, int op_version)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
num_parallel_calls_(num_parallel_calls),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes),
|
||||
sloppy_(sloppy),
|
||||
deterministic_(deterministic),
|
||||
preserve_cardinality_(preserve_cardinality),
|
||||
captured_func_(std::move(captured_func)) {
|
||||
captured_func_(std::move(captured_func)),
|
||||
op_version_(op_version) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
@ -79,10 +82,14 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
const string& prefix) const override {
|
||||
std::unique_ptr<ParallelMapFunctor> parallel_map_functor =
|
||||
absl::make_unique<ParallelMapDatasetFunctor>(this);
|
||||
bool deterministic =
|
||||
deterministic_.IsDeterministic() || deterministic_.IsDefault();
|
||||
name_utils::IteratorPrefixParams params;
|
||||
params.op_version = op_version_;
|
||||
return NewParallelMapIterator(
|
||||
{this, name_utils::IteratorPrefix(kDatasetType, prefix)}, input_,
|
||||
std::move(parallel_map_functor), num_parallel_calls_, sloppy_,
|
||||
preserve_cardinality_);
|
||||
{this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
|
||||
input_, std::move(parallel_map_functor), num_parallel_calls_,
|
||||
deterministic, preserve_cardinality_);
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override { return output_types_; }
|
||||
@ -92,7 +99,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return name_utils::DatasetDebugString(ParallelMapDatasetOp::kDatasetType);
|
||||
name_utils::DatasetDebugStringParams params;
|
||||
params.op_version = op_version_;
|
||||
return name_utils::DatasetDebugString(ParallelMapDatasetOp::kDatasetType,
|
||||
params);
|
||||
}
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
@ -118,41 +128,54 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
// Input: num_parallel_calls
|
||||
Node* num_parallel_calls = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(num_parallel_calls_, &num_parallel_calls));
|
||||
if (op_version_ == 1) {
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(static_cast<int32>(num_parallel_calls_),
|
||||
&num_parallel_calls));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddScalar(num_parallel_calls_, &num_parallel_calls));
|
||||
}
|
||||
std::vector<std::pair<StringPiece, AttrValue>> attrs;
|
||||
|
||||
// Attr: f
|
||||
AttrValue f_attr;
|
||||
b->BuildAttrValue(captured_func_->func(), &f_attr);
|
||||
attrs.emplace_back(kFunc, f_attr);
|
||||
|
||||
// Attr: Targuments
|
||||
AttrValue other_arguments_types_attr;
|
||||
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
|
||||
attrs.emplace_back(kTarguments, other_arguments_types_attr);
|
||||
|
||||
// Attr: use_inter_op_parallelism
|
||||
AttrValue use_inter_op_parallelism_attr;
|
||||
b->BuildAttrValue(captured_func_->use_inter_op_parallelism(),
|
||||
&use_inter_op_parallelism_attr);
|
||||
attrs.emplace_back(kUseInterOpParallelism, use_inter_op_parallelism_attr);
|
||||
|
||||
// Attr: sloppy
|
||||
AttrValue sloppy_attr;
|
||||
b->BuildAttrValue(sloppy_, &sloppy_attr);
|
||||
if (op_version_ == 1) {
|
||||
// Attr: sloppy
|
||||
AttrValue sloppy_attr;
|
||||
b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
|
||||
attrs.emplace_back(kSloppy, sloppy_attr);
|
||||
}
|
||||
if (op_version_ == 2) {
|
||||
AttrValue deterministic_attr;
|
||||
b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
|
||||
attrs.emplace_back(kDeterministic, deterministic_attr);
|
||||
}
|
||||
|
||||
// Attr: preserve_cardinality
|
||||
AttrValue preserve_cardinality_attr;
|
||||
b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
|
||||
attrs.emplace_back(kPreserveCardinality, preserve_cardinality_attr);
|
||||
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this,
|
||||
{std::make_pair(0, input_graph_node),
|
||||
std::make_pair(2, num_parallel_calls)}, // Single tensor inputs.
|
||||
{std::make_pair(1, other_arguments)}, // Tensor list inputs.
|
||||
{std::make_pair(kFunc, f_attr),
|
||||
std::make_pair(kTarguments, other_arguments_types_attr),
|
||||
std::make_pair(kUseInterOpParallelism, use_inter_op_parallelism_attr),
|
||||
std::make_pair(kSloppy, sloppy_attr),
|
||||
std::make_pair(kPreserveCardinality,
|
||||
preserve_cardinality_attr)}, // Attrs
|
||||
output));
|
||||
attrs, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -192,16 +215,17 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
const int32 num_parallel_calls_;
|
||||
const int64 num_parallel_calls_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
const bool sloppy_;
|
||||
const DeterminismPolicy deterministic_;
|
||||
const bool preserve_cardinality_;
|
||||
const std::unique_ptr<CapturedFunction> captured_func_;
|
||||
const int op_version_;
|
||||
};
|
||||
|
||||
ParallelMapDatasetOp::ParallelMapDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
: UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 1 : 2) {
|
||||
FunctionMetadata::Params params;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseInterOpParallelism,
|
||||
¶ms.use_inter_op_parallelism));
|
||||
@ -210,16 +234,39 @@ ParallelMapDatasetOp::ParallelMapDatasetOp(OpKernelConstruction* ctx)
|
||||
FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy_));
|
||||
if (op_version_ == 1) {
|
||||
bool sloppy;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy));
|
||||
if (sloppy) {
|
||||
deterministic_ =
|
||||
DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
|
||||
} else {
|
||||
deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
|
||||
}
|
||||
}
|
||||
if (op_version_ == 2) {
|
||||
std::string deterministic;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
|
||||
}
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
|
||||
}
|
||||
|
||||
void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) {
|
||||
int32 num_parallel_calls;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
|
||||
int64 num_parallel_calls;
|
||||
if (op_version_ == 1) {
|
||||
int32 parallel_calls;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ParseScalarArgument(ctx, kNumParallelCalls, ¶llel_calls));
|
||||
num_parallel_calls = parallel_calls;
|
||||
}
|
||||
if (op_version_ == 2) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
|
||||
}
|
||||
OP_REQUIRES(
|
||||
ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
|
||||
errors::InvalidArgument("num_parallel_calls must be greater than zero."));
|
||||
@ -235,7 +282,8 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
|
||||
*output =
|
||||
new Dataset(ctx, input, num_parallel_calls, output_types_, output_shapes_,
|
||||
sloppy_, std::move(captured_func), preserve_cardinality_);
|
||||
deterministic_, std::move(captured_func),
|
||||
preserve_cardinality_, op_version_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -250,15 +298,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
public:
|
||||
struct Params {
|
||||
Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
|
||||
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
|
||||
int64 num_parallel_calls, bool deterministic,
|
||||
bool preserve_cardinality)
|
||||
: parallel_map_functor(std::move(parallel_map_functor)),
|
||||
num_parallel_calls(num_parallel_calls),
|
||||
sloppy(sloppy),
|
||||
deterministic(deterministic),
|
||||
preserve_cardinality(preserve_cardinality) {}
|
||||
|
||||
std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
|
||||
int32 num_parallel_calls;
|
||||
bool sloppy;
|
||||
int64 num_parallel_calls;
|
||||
bool deterministic;
|
||||
bool preserve_cardinality;
|
||||
};
|
||||
|
||||
@ -271,7 +320,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
cond_var_(std::make_shared<condition_variable>()),
|
||||
num_parallel_calls_(std::make_shared<model::SharedState>(
|
||||
params.num_parallel_calls, mu_, cond_var_)),
|
||||
sloppy_(params.sloppy),
|
||||
deterministic_(params.deterministic),
|
||||
preserve_cardinality_(params.preserve_cardinality),
|
||||
autotune_(params.num_parallel_calls == model::kAutotune),
|
||||
key_prefix_(base_params.dataset->node_name()) {}
|
||||
@ -416,7 +465,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
data::TraceMeMetadata result;
|
||||
result.push_back(std::make_pair("autotune", autotune_ ? "true" : "false"));
|
||||
result.push_back(
|
||||
std::make_pair("deterministic", sloppy_ ? "false" : "true"));
|
||||
std::make_pair("deterministic", deterministic_ ? "true" : "false"));
|
||||
result.push_back(
|
||||
std::make_pair("parallelism", strings::Printf("%lld", parallelism)));
|
||||
return result;
|
||||
@ -564,7 +613,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
if (cancelled_) {
|
||||
return false;
|
||||
}
|
||||
if (sloppy_) {
|
||||
if (!deterministic_) {
|
||||
for (auto it = invocation_results_.begin();
|
||||
it != invocation_results_.end(); ++it) {
|
||||
if ((*it)->notification.HasBeenNotified() &&
|
||||
@ -663,8 +712,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
const std::shared_ptr<condition_variable> cond_var_;
|
||||
// Identifies the maximum number of parallel calls.
|
||||
const std::shared_ptr<model::SharedState> num_parallel_calls_;
|
||||
// Determines whether outputs can be produced in non-deterministic order.
|
||||
const bool sloppy_;
|
||||
// Whether outputs must be produced in deterministic order.
|
||||
const bool deterministic_;
|
||||
const bool preserve_cardinality_;
|
||||
const bool autotune_;
|
||||
const string key_prefix_;
|
||||
@ -689,18 +738,21 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
|
||||
const DatasetBaseIterator::BaseParams& params,
|
||||
const DatasetBase* input_dataset,
|
||||
std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
|
||||
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
|
||||
int64 num_parallel_calls, bool deterministic, bool preserve_cardinality) {
|
||||
return absl::make_unique<ParallelMapIterator>(
|
||||
params, input_dataset,
|
||||
ParallelMapIterator::Params{std::move(parallel_map_functor),
|
||||
num_parallel_calls, sloppy,
|
||||
num_parallel_calls, deterministic,
|
||||
preserve_cardinality});
|
||||
}
|
||||
|
||||
namespace {
|
||||
REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
|
||||
ParallelMapDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ParallelMapDatasetV2").Device(DEVICE_CPU),
|
||||
ParallelMapDatasetOp);
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelMapDataset");
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelMapDatasetV2");
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/kernels/data/captured_function.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -33,6 +34,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
static constexpr const char* const kUseInterOpParallelism =
|
||||
"use_inter_op_parallelism";
|
||||
static constexpr const char* const kDeterministic = "deterministic";
|
||||
static constexpr const char* const kSloppy = "sloppy";
|
||||
static constexpr const char* const kPreserveCardinality =
|
||||
"preserve_cardinality";
|
||||
@ -45,11 +47,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
private:
|
||||
class Dataset;
|
||||
const int op_version_;
|
||||
std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
bool sloppy_;
|
||||
bool preserve_cardinality_;
|
||||
DeterminismPolicy deterministic_;
|
||||
};
|
||||
|
||||
class ParallelMapFunctor {
|
||||
@ -78,7 +82,7 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
|
||||
const DatasetBaseIterator::BaseParams& params,
|
||||
const DatasetBase* input_dataset,
|
||||
std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
|
||||
int32 num_parallel_calls, bool sloppy, bool preserve_cardinality);
|
||||
int64 num_parallel_calls, bool deterministic, bool preserve_cardinality);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -12,26 +12,26 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
|
||||
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "parallel_map_dataset";
|
||||
constexpr int kOpVersion = 2;
|
||||
|
||||
class ParallelMapDatasetParams : public DatasetParams {
|
||||
public:
|
||||
template <typename T>
|
||||
ParallelMapDatasetParams(T input_dataset_params,
|
||||
std::vector<Tensor> other_arguments,
|
||||
int num_parallel_calls,
|
||||
FunctionDefHelper::AttrValueWrapper func,
|
||||
std::vector<FunctionDef> func_lib,
|
||||
DataTypeVector type_arguments,
|
||||
DataTypeVector output_dtypes,
|
||||
std::vector<PartialTensorShape> output_shapes,
|
||||
bool use_inter_op_parallelism, bool sloppy,
|
||||
bool preserve_cardinality, string node_name)
|
||||
ParallelMapDatasetParams(
|
||||
T input_dataset_params, std::vector<Tensor> other_arguments,
|
||||
int num_parallel_calls, FunctionDefHelper::AttrValueWrapper func,
|
||||
std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
|
||||
const DataTypeVector& output_dtypes,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
bool use_inter_op_parallelism, const std::string& deterministic,
|
||||
bool preserve_cardinality, string node_name)
|
||||
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
|
||||
std::move(node_name)),
|
||||
other_arguments_(std::move(other_arguments)),
|
||||
@ -40,18 +40,21 @@ class ParallelMapDatasetParams : public DatasetParams {
|
||||
func_lib_(std::move(func_lib)),
|
||||
type_arguments_(std::move(type_arguments)),
|
||||
use_inter_op_parallelism_(use_inter_op_parallelism),
|
||||
sloppy_(sloppy),
|
||||
deterministic_(deterministic),
|
||||
preserve_cardinality_(preserve_cardinality) {
|
||||
input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
|
||||
iterator_prefix_ =
|
||||
name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
|
||||
input_dataset_params.iterator_prefix());
|
||||
op_version_ = kOpVersion;
|
||||
name_utils::IteratorPrefixParams params;
|
||||
params.op_version = op_version_;
|
||||
iterator_prefix_ = name_utils::IteratorPrefix(
|
||||
input_dataset_params.dataset_type(),
|
||||
input_dataset_params.iterator_prefix(), params);
|
||||
}
|
||||
|
||||
std::vector<Tensor> GetInputTensors() const override {
|
||||
auto input_tensors = other_arguments_;
|
||||
input_tensors.emplace_back(
|
||||
CreateTensor<int32>(TensorShape({}), {num_parallel_calls_}));
|
||||
CreateTensor<int64>(TensorShape({}), {num_parallel_calls_}));
|
||||
return input_tensors;
|
||||
}
|
||||
|
||||
@ -73,7 +76,7 @@ class ParallelMapDatasetParams : public DatasetParams {
|
||||
{ParallelMapDatasetOp::kOutputTypes, output_dtypes_},
|
||||
{ParallelMapDatasetOp::kUseInterOpParallelism,
|
||||
use_inter_op_parallelism_},
|
||||
{ParallelMapDatasetOp::kSloppy, sloppy_},
|
||||
{ParallelMapDatasetOp::kDeterministic, deterministic_},
|
||||
{ParallelMapDatasetOp::kPreserveCardinality, preserve_cardinality_}};
|
||||
return Status::OK();
|
||||
}
|
||||
@ -91,7 +94,7 @@ class ParallelMapDatasetParams : public DatasetParams {
|
||||
std::vector<FunctionDef> func_lib_;
|
||||
DataTypeVector type_arguments_;
|
||||
bool use_inter_op_parallelism_;
|
||||
bool sloppy_;
|
||||
std::string deterministic_;
|
||||
bool preserve_cardinality_;
|
||||
};
|
||||
|
||||
@ -103,41 +106,43 @@ FunctionDefHelper::AttrValueWrapper MapFunc(const string& func_name,
|
||||
}
|
||||
|
||||
// test case 1: num_parallel_calls = 1, use_inter_op_parallelism = false,
|
||||
// sloppy = false, preserve_cardinality = false, MapFunc = XTimesTwo
|
||||
// deterministic = true, preserve_cardinality = false, MapFunc = XTimesTwo
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams1() {
|
||||
return ParallelMapDatasetParams(RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/1,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/false,
|
||||
/*sloppy=*/false,
|
||||
/*preserve_cardinality=*/false,
|
||||
/*node_name=*/kNodeName);
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/1,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/false,
|
||||
/*deterministic=*/DeterminismPolicy::kDeterministic,
|
||||
/*preserve_cardinality=*/false,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// test case 2: num_parallel_calls = 2, use_inter_op_parallelism = true,
|
||||
// sloppy = true, preserve_cardinality = true, MapFunc = XTimesTwo
|
||||
// deterministic = false, preserve_cardinality = true, MapFunc = XTimesTwo
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams2() {
|
||||
return ParallelMapDatasetParams(RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/2,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/true,
|
||||
/*sloppy=*/true,
|
||||
/*preserve_cardinality=*/true,
|
||||
/*node_name=*/kNodeName);
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/2,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/true,
|
||||
/*deterministic=*/DeterminismPolicy::kNondeterministic,
|
||||
/*preserve_cardinality=*/true,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// test case 3: num_parallel_calls = 3, use_inter_op_parallelism = true,
|
||||
// sloppy = false, preserve_cardinality = false, MapFunc = XTimesFour
|
||||
// deterministic = true, preserve_cardinality = false, MapFunc = XTimesFour
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams3() {
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
@ -149,30 +154,31 @@ ParallelMapDatasetParams ParallelMapDatasetParams3() {
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/true,
|
||||
/*sloppy=*/false,
|
||||
/*deterministic=*/DeterminismPolicy::kDeterministic,
|
||||
/*preserve_cardinality=*/false,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// test case 4: num_parallel_calls = 4, use_inter_op_parallelism = false,
|
||||
// sloppy = false, preserve_cardinality = false, MapFunc = XTimesTwo
|
||||
// deterministic = true, preserve_cardinality = false, MapFunc = XTimesTwo
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams4() {
|
||||
return ParallelMapDatasetParams(RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/4,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/false,
|
||||
/*sloppy=*/false,
|
||||
/*preserve_cardinality=*/false,
|
||||
/*node_name=*/kNodeName);
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/4,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/false,
|
||||
/*deterministic=*/DeterminismPolicy::kDeterministic,
|
||||
/*preserve_cardinality=*/false,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// test case 5: num_parallel_calls = kAutotune, use_inter_op_parallelism = true,
|
||||
// sloppy = true, preserve_cardinality = true, MapFunc = XTimesFour
|
||||
// deterministic = false, preserve_cardinality = true, MapFunc = XTimesFour
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams5() {
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
@ -184,13 +190,13 @@ ParallelMapDatasetParams ParallelMapDatasetParams5() {
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/true,
|
||||
/*sloppy=*/true,
|
||||
/*deterministic=*/DeterminismPolicy::kNondeterministic,
|
||||
/*preserve_cardinality=*/true,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// test case 6: num_parallel_calls = 4, use_inter_op_parallelism = true,
|
||||
// sloppy = false, preserve_cardinality = false, MapFunc = XTimesFour
|
||||
// deterministic = true, preserve_cardinality = false, MapFunc = XTimesFour
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams6() {
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
@ -202,14 +208,14 @@ ParallelMapDatasetParams ParallelMapDatasetParams6() {
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/true,
|
||||
/*sloppy=*/false,
|
||||
/*deterministic=*/DeterminismPolicy::kDeterministic,
|
||||
/*preserve_cardinality=*/false,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// TODO(feihugis): make this test case work.
|
||||
// test case 7: num_parallel_calls = 2, use_inter_op_parallelism = false,
|
||||
// sloppy = false, preserve_cardinality = false, MapFunc = XTimesFour
|
||||
// deterministic = true, preserve_cardinality = false, MapFunc = XTimesFour
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams7() {
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
@ -221,14 +227,15 @@ ParallelMapDatasetParams ParallelMapDatasetParams7() {
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/false,
|
||||
/*sloppy=*/false,
|
||||
/*deterministic=*/DeterminismPolicy::kDeterministic,
|
||||
/*preserve_cardinality=*/false,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// TODO(feihugis): make this test case work.
|
||||
// test case 8: num_parallel_calls = kAutotune, use_inter_op_parallelism =
|
||||
// false, sloppy = true, preserve_cardinality = true, MapFunc = XTimesFour
|
||||
// false, deterministic = false, preserve_cardinality = true, MapFunc =
|
||||
// XTimesFour
|
||||
ParallelMapDatasetParams ParallelMapDatasetParams8() {
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
@ -240,24 +247,25 @@ ParallelMapDatasetParams ParallelMapDatasetParams8() {
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/false,
|
||||
/*sloppy=*/true,
|
||||
/*deterministic=*/DeterminismPolicy::kNondeterministic,
|
||||
/*preserve_cardinality=*/true,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
ParallelMapDatasetParams ParallelMapDatasetParamsWithInvalidNumParallelCalls() {
|
||||
return ParallelMapDatasetParams(RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/-4,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/true,
|
||||
/*sloppy=*/true,
|
||||
/*preserve_cardinality=*/true,
|
||||
/*node_name=*/kNodeName);
|
||||
return ParallelMapDatasetParams(
|
||||
RangeDatasetParams(0, 10, 3),
|
||||
/*other_arguments=*/{},
|
||||
/*num_parallel_calls=*/-4,
|
||||
/*func=*/MapFunc("XTimesTwo", DT_INT64),
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*type_arguments=*/{},
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({})},
|
||||
/*use_inter_op_parallelism=*/true,
|
||||
/*deterministic=*/DeterminismPolicy::kNondeterministic,
|
||||
/*preserve_cardinality=*/true,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
std::vector<GetNextTestCase<ParallelMapDatasetParams>> GetNextTestCases() {
|
||||
@ -300,8 +308,10 @@ TEST_F(ParallelMapDatasetOpTest, DatasetNodeName) {
|
||||
TEST_F(ParallelMapDatasetOpTest, DatasetTypeString) {
|
||||
auto dataset_params = ParallelMapDatasetParams1();
|
||||
TF_ASSERT_OK(Initialize(dataset_params));
|
||||
name_utils::OpNameParams params;
|
||||
params.op_version = dataset_params.op_version();
|
||||
TF_ASSERT_OK(CheckDatasetTypeString(
|
||||
name_utils::OpName(ParallelMapDatasetOp::kDatasetType)));
|
||||
name_utils::OpName(ParallelMapDatasetOp::kDatasetType, params)));
|
||||
}
|
||||
|
||||
TEST_F(ParallelMapDatasetOpTest, DatasetOutputDtypes) {
|
||||
@ -350,8 +360,11 @@ TEST_F(ParallelMapDatasetOpTest, IteratorOutputShapes) {
|
||||
TEST_F(ParallelMapDatasetOpTest, IteratorPrefix) {
|
||||
auto dataset_params = ParallelMapDatasetParams1();
|
||||
TF_ASSERT_OK(Initialize(dataset_params));
|
||||
TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
|
||||
ParallelMapDatasetOp::kDatasetType, dataset_params.iterator_prefix())));
|
||||
name_utils::IteratorPrefixParams params;
|
||||
params.op_version = dataset_params.op_version();
|
||||
TF_ASSERT_OK(CheckIteratorPrefix(
|
||||
name_utils::IteratorPrefix(ParallelMapDatasetOp::kDatasetType,
|
||||
dataset_params.iterator_prefix(), params)));
|
||||
}
|
||||
|
||||
std::vector<IteratorSaveAndRestoreTestCase<ParallelMapDatasetParams>>
|
||||
|
@ -161,6 +161,21 @@ REGISTER_OP("ParallelMapDataset")
|
||||
.Attr("preserve_cardinality: bool = false")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ParallelMapDatasetV2")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Input("num_parallel_calls: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("use_inter_op_parallelism: bool = true")
|
||||
// "true", "false", or "default".
|
||||
.Attr("deterministic: string = 'default'")
|
||||
.Attr("preserve_cardinality: bool = false")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("PrefetchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("buffer_size: int64")
|
||||
|
@ -227,7 +227,7 @@ tf_py_test(
|
||||
name = "map_vectorization_test",
|
||||
size = "small",
|
||||
srcs = ["map_vectorization_test.py"],
|
||||
shard_count = 8,
|
||||
shard_count = 16,
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
|
@ -37,8 +37,11 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testParallelMap(self):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
parallel_map = "ParallelMap"
|
||||
if compat.forward_compatible(2020, 2, 20):
|
||||
parallel_map = "ParallelMapV2"
|
||||
dataset = dataset.apply(
|
||||
testing.assert_next(["ParallelMap", "Prefetch", "FiniteTake"]))
|
||||
testing.assert_next([parallel_map, "Prefetch", "FiniteTake"]))
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.take(50)
|
||||
@ -83,9 +86,12 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
parallel_interleave = "ParallelInterleaveV3"
|
||||
if compat.forward_compatible(2020, 3, 6):
|
||||
parallel_interleave = "ParallelInterleaveV4"
|
||||
parallel_map = "ParallelMap"
|
||||
if compat.forward_compatible(2020, 2, 20):
|
||||
parallel_map = "ParallelMapV2"
|
||||
dataset = dataset.apply(
|
||||
testing.assert_next([
|
||||
"ParallelMap", "Prefetch", parallel_interleave, "Prefetch",
|
||||
parallel_map, "Prefetch", parallel_interleave, "Prefetch",
|
||||
"MapAndBatch", "Prefetch", "FiniteTake"
|
||||
]))
|
||||
dataset = dataset.map(
|
||||
|
@ -18,12 +18,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
@ -42,6 +44,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -217,7 +220,11 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
Returns:
|
||||
Tuple of (unoptimized dataset, optimized dataset).
|
||||
"""
|
||||
map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
|
||||
map_node_name = "Map"
|
||||
if num_parallel_calls is not None:
|
||||
map_node_name = "ParallelMap"
|
||||
if compat.forward_compatible(2020, 2, 20):
|
||||
map_node_name = "ParallelMapV2"
|
||||
|
||||
def _make_dataset(node_names):
|
||||
dataset = base_dataset.apply(testing.assert_next(node_names))
|
||||
@ -514,20 +521,14 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def map_fn(x):
|
||||
return x * 2
|
||||
|
||||
unoptimized_seq = []
|
||||
|
||||
def make_apply_fn(is_fused):
|
||||
if is_fused:
|
||||
unoptimized_seq.append("MapAndBatch")
|
||||
|
||||
def apply_fn(dataset):
|
||||
return dataset.apply(
|
||||
batching.map_and_batch(map_fn, 2, 12, drop_remainder=True))
|
||||
|
||||
return apply_fn
|
||||
else:
|
||||
unoptimized_seq.extend(["ParallelMap", "BatchV2"])
|
||||
|
||||
def apply_fn(dataset):
|
||||
return dataset.map(map_fn, 12).batch(2, drop_remainder=True)
|
||||
|
||||
@ -541,17 +542,60 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
apply_fn_1 = make_apply_fn(fuse_first)
|
||||
apply_fn_2 = make_apply_fn(fuse_second)
|
||||
|
||||
def make_dataset(node_names):
|
||||
dataset = base_dataset.apply(testing.assert_next(node_names))
|
||||
def make_dataset():
|
||||
dataset = base_dataset
|
||||
dataset = apply_fn_1(dataset)
|
||||
dataset = apply_fn_2(dataset)
|
||||
return dataset
|
||||
|
||||
unoptimized = make_dataset(unoptimized_seq)
|
||||
optimized = make_dataset(["ChooseFastestBranch", "ChooseFastestBranch"])
|
||||
unoptimized = make_dataset()
|
||||
optimized = make_dataset()
|
||||
optimized = self._enable_map_vectorization(optimized)
|
||||
self.assertDatasetsEqual(optimized, unoptimized)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
local_determinism=[True, False, None],
|
||||
global_determinism=[True, False])))
|
||||
def testOptimizationDeterminism(self, local_determinism, global_determinism):
|
||||
# Tests that vectorization maintains the determinism setting.
|
||||
expect_determinism = local_determinism or (local_determinism is None and
|
||||
global_determinism)
|
||||
elements = list(range(1000))
|
||||
|
||||
def dataset_fn(delay_ms):
|
||||
|
||||
def sleep(x):
|
||||
time.sleep(delay_ms / 1000)
|
||||
return x
|
||||
|
||||
def map_function(x):
|
||||
if math_ops.equal(x, 0):
|
||||
return check_ops.ensure_shape(
|
||||
script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())
|
||||
else:
|
||||
return x
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(elements)
|
||||
dataset = dataset.map(
|
||||
map_function, num_parallel_calls=10, deterministic=local_determinism)
|
||||
dataset = dataset.batch(1)
|
||||
|
||||
opts = dataset_ops.Options()
|
||||
opts.experimental_deterministic = global_determinism
|
||||
# Prevent the map/batch from being rewritten as MapAndBatch.
|
||||
opts.experimental_optimization.apply_default_optimizations = False
|
||||
dataset = dataset.with_options(opts)
|
||||
dataset = self._enable_map_vectorization(dataset)
|
||||
return dataset
|
||||
|
||||
self.checkDeterminism(
|
||||
dataset_fn,
|
||||
expect_determinism,
|
||||
expected_elements=[[element] for element in elements])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptimizationIgnoreStateful(self):
|
||||
|
||||
|
@ -318,8 +318,11 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
local_determinism=[None, True, False],
|
||||
global_determinism=[True, False])))
|
||||
def testDeterminismConfiguration(self, local_determinism, global_determinism):
|
||||
expect_determinism = local_determinism or (local_determinism is None and
|
||||
global_determinism)
|
||||
elements = list(range(1000))
|
||||
|
||||
def make_interleave_fn(delay_ms):
|
||||
def dataset_fn(delay_ms):
|
||||
|
||||
def interleave_fn(x):
|
||||
ds = dataset_ops.Dataset.from_tensors(x)
|
||||
@ -329,36 +332,18 @@ class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = ds.apply(testing.sleep(0))
|
||||
return ds
|
||||
|
||||
return interleave_fn
|
||||
|
||||
expect_determinism = local_determinism or (local_determinism is None and
|
||||
global_determinism)
|
||||
if expect_determinism:
|
||||
delays_ms = [100]
|
||||
else:
|
||||
delays_ms = [10, 100, 1000, 20000]
|
||||
# We consider the test a success if it succeeds under any delay_ms. The
|
||||
# delay_ms needed to observe non-deterministic ordering varies across
|
||||
# test machines. Usually 10 or 100 milliseconds is enough, but on slow
|
||||
# machines it could take longer.
|
||||
for delay_ms in delays_ms:
|
||||
dataset = dataset_ops.Dataset.range(2)
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(elements)
|
||||
dataset = dataset.interleave(
|
||||
make_interleave_fn(delay_ms),
|
||||
cycle_length=2,
|
||||
num_parallel_calls=2,
|
||||
interleave_fn,
|
||||
cycle_length=10,
|
||||
num_parallel_calls=10,
|
||||
deterministic=local_determinism)
|
||||
|
||||
opts = dataset_ops.Options()
|
||||
opts.experimental_deterministic = global_determinism
|
||||
dataset = dataset.with_options(opts)
|
||||
return dataset
|
||||
|
||||
expected = [0, 1] if expect_determinism else [1, 0]
|
||||
actual = self.getDatasetOutput(dataset)
|
||||
if actual == expected:
|
||||
return
|
||||
self.assertEqual(expected, actual)
|
||||
self.checkDeterminism(dataset_fn, expect_determinism, elements)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1323,6 +1323,44 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output=expected_output,
|
||||
requires_initialization=True)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
_test_combinations(),
|
||||
combinations.combine(
|
||||
local_determinism=[None, True, False],
|
||||
global_determinism=[True, False])))
|
||||
def testDeterminismConfiguration(self, apply_map, local_determinism,
|
||||
global_determinism):
|
||||
expect_determinism = local_determinism or (local_determinism is None and
|
||||
global_determinism)
|
||||
elements = list(range(1000))
|
||||
|
||||
def dataset_fn(delay_ms):
|
||||
|
||||
def sleep(x):
|
||||
time.sleep(delay_ms / 1000)
|
||||
return x
|
||||
|
||||
def map_function(x):
|
||||
if math_ops.equal(x, 0):
|
||||
return script_ops.py_func(sleep, [x], x.dtype)
|
||||
else:
|
||||
return x
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(elements)
|
||||
dataset = apply_map(
|
||||
dataset,
|
||||
map_function,
|
||||
num_parallel_calls=2,
|
||||
deterministic=local_determinism)
|
||||
opts = dataset_ops.Options()
|
||||
opts.experimental_deterministic = global_determinism
|
||||
dataset = dataset.with_options(opts)
|
||||
return dataset
|
||||
|
||||
self.checkDeterminism(
|
||||
dataset_fn, expect_determinism, expected_elements=elements)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -297,3 +297,34 @@ class DatasetTestBase(test.TestCase):
|
||||
self.structuredElement(substructure, shape, dtype)
|
||||
for substructure in element_structure
|
||||
])
|
||||
|
||||
def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements):
|
||||
"""Tests whether a dataset produces its elements deterministically.
|
||||
|
||||
`dataset_fn` takes a delay_ms argument, which tells it how long to delay
|
||||
production of the first dataset element. This gives us a way to trigger
|
||||
out-of-order production of dataset elements.
|
||||
|
||||
Args:
|
||||
dataset_fn: A function taking a delay_ms argument.
|
||||
expect_determinism: Whether to expect deterministic ordering.
|
||||
expected_elements: The elements expected to be produced by the dataset,
|
||||
assuming the dataset produces elements in deterministic order.
|
||||
"""
|
||||
if expect_determinism:
|
||||
dataset = dataset_fn(100)
|
||||
actual = self.getDatasetOutput(dataset)
|
||||
self.assertAllEqual(expected_elements, actual)
|
||||
return
|
||||
|
||||
# We consider the test a success if it succeeds under any delay_ms. The
|
||||
# delay_ms needed to observe non-deterministic ordering varies across
|
||||
# test machines. Usually 10 or 100 milliseconds is enough, but on slow
|
||||
# machines it could take longer.
|
||||
for delay_ms in [10, 100, 1000, 20000]:
|
||||
dataset = dataset_fn(delay_ms)
|
||||
actual = self.getDatasetOutput(dataset)
|
||||
self.assertCountEqual(expected_elements, actual)
|
||||
if actual[0] != expected_elements[0]:
|
||||
return
|
||||
self.fail("Failed to observe nondeterministic ordering")
|
||||
|
@ -1483,7 +1483,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
|
||||
drop_remainder)
|
||||
|
||||
def map(self, map_func, num_parallel_calls=None):
|
||||
def map(self, map_func, num_parallel_calls=None, deterministic=None):
|
||||
"""Maps `map_func` across the elements of this dataset.
|
||||
|
||||
This transformation applies `map_func` to each element of this dataset, and
|
||||
@ -1576,6 +1576,16 @@ name=None))
|
||||
>>> list(d.as_numpy_iterator())
|
||||
[b'HELLO', b'WORLD']
|
||||
|
||||
Performance can often be improved by setting `num_parallel_calls` so that
|
||||
`map` will use multiple threads to process elements. If deterministic order
|
||||
isn't required, it can also improve performance to set
|
||||
`deterministic=False`.
|
||||
|
||||
>>> dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
|
||||
>>> dataset = dataset.map(lambda x: x + 1,
|
||||
... num_parallel_calls=tf.data.experimental.AUTOTUNE,
|
||||
... deterministic=False)
|
||||
|
||||
Args:
|
||||
map_func: A function mapping a dataset element to another dataset element.
|
||||
num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
|
||||
@ -1583,6 +1593,12 @@ name=None))
|
||||
If not specified, elements will be processed sequentially. If the value
|
||||
`tf.data.experimental.AUTOTUNE` is used, then the number of parallel
|
||||
calls is set dynamically based on available CPU.
|
||||
deterministic: (Optional.) A boolean controlling whether determinism
|
||||
should be traded for performance by allowing elements to be produced out
|
||||
of order. If `deterministic` is `None`, the
|
||||
`tf.data.Options.experimental_deterministic` dataset option (`True` by
|
||||
default) is used to decide whether to produce elements
|
||||
deterministically.
|
||||
|
||||
Returns:
|
||||
Dataset: A `Dataset`.
|
||||
@ -1591,7 +1607,11 @@ name=None))
|
||||
return MapDataset(self, map_func, preserve_cardinality=True)
|
||||
else:
|
||||
return ParallelMapDataset(
|
||||
self, map_func, num_parallel_calls, preserve_cardinality=True)
|
||||
self,
|
||||
map_func,
|
||||
num_parallel_calls,
|
||||
deterministic,
|
||||
preserve_cardinality=True)
|
||||
|
||||
def flat_map(self, map_func):
|
||||
"""Maps `map_func` across this dataset and flattens the result.
|
||||
@ -2299,21 +2319,29 @@ class DatasetV1(DatasetV2):
|
||||
padded_shapes=None,
|
||||
padding_values=None,
|
||||
drop_remainder=False):
|
||||
return DatasetV1Adapter(super(DatasetV1, self).padded_batch(
|
||||
batch_size, padded_shapes, padding_values, drop_remainder))
|
||||
return DatasetV1Adapter(
|
||||
super(DatasetV1, self).padded_batch(batch_size, padded_shapes,
|
||||
padding_values, drop_remainder))
|
||||
|
||||
@functools.wraps(DatasetV2.map)
|
||||
def map(self, map_func, num_parallel_calls=None):
|
||||
def map(self, map_func, num_parallel_calls=None, deterministic=None):
|
||||
if num_parallel_calls is None:
|
||||
return DatasetV1Adapter(
|
||||
MapDataset(self, map_func, preserve_cardinality=False))
|
||||
else:
|
||||
return DatasetV1Adapter(
|
||||
ParallelMapDataset(
|
||||
self, map_func, num_parallel_calls, preserve_cardinality=False))
|
||||
self,
|
||||
map_func,
|
||||
num_parallel_calls,
|
||||
deterministic,
|
||||
preserve_cardinality=False))
|
||||
|
||||
@deprecation.deprecated(None, "Use `tf.data.Dataset.map()")
|
||||
def map_with_legacy_function(self, map_func, num_parallel_calls=None):
|
||||
def map_with_legacy_function(self,
|
||||
map_func,
|
||||
num_parallel_calls=None,
|
||||
deterministic=None):
|
||||
"""Maps `map_func` across the elements of this dataset.
|
||||
|
||||
NOTE: This is an escape hatch for existing uses of `map` that do not work
|
||||
@ -2329,6 +2357,12 @@ class DatasetV1(DatasetV2):
|
||||
If not specified, elements will be processed sequentially. If the value
|
||||
`tf.data.experimental.AUTOTUNE` is used, then the number of parallel
|
||||
calls is set dynamically based on available CPU.
|
||||
deterministic: (Optional.) A boolean controlling whether determinism
|
||||
should be traded for performance by allowing elements to be produced out
|
||||
of order. If `deterministic` is `None`, the
|
||||
`tf.data.Options.experimental_deterministic` dataset option (`True` by
|
||||
default) is used to decide whether to produce elements
|
||||
deterministically.
|
||||
|
||||
Returns:
|
||||
Dataset: A `Dataset`.
|
||||
@ -2346,6 +2380,7 @@ class DatasetV1(DatasetV2):
|
||||
self,
|
||||
map_func,
|
||||
num_parallel_calls,
|
||||
deterministic,
|
||||
preserve_cardinality=False,
|
||||
use_legacy_function=True))
|
||||
|
||||
@ -3933,6 +3968,7 @@ class ParallelMapDataset(UnaryDataset):
|
||||
input_dataset,
|
||||
map_func,
|
||||
num_parallel_calls,
|
||||
deterministic,
|
||||
use_inter_op_parallelism=True,
|
||||
preserve_cardinality=False,
|
||||
use_legacy_function=False):
|
||||
@ -3944,17 +3980,36 @@ class ParallelMapDataset(UnaryDataset):
|
||||
self._transformation_name(),
|
||||
dataset=input_dataset,
|
||||
use_legacy_function=use_legacy_function)
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
|
||||
if deterministic is None:
|
||||
self._deterministic = "default"
|
||||
elif deterministic:
|
||||
self._deterministic = "true"
|
||||
else:
|
||||
self._deterministic = "false"
|
||||
self._preserve_cardinality = preserve_cardinality
|
||||
variant_tensor = gen_dataset_ops.parallel_map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
num_parallel_calls=self._num_parallel_calls,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**self._flat_structure)
|
||||
if deterministic is not None or compat.forward_compatible(2020, 3, 6):
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
|
||||
variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
num_parallel_calls=self._num_parallel_calls,
|
||||
deterministic=self._deterministic,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**self._flat_structure)
|
||||
else:
|
||||
self._num_parallel_calls = ops.convert_to_tensor(
|
||||
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
|
||||
variant_tensor = gen_dataset_ops.parallel_map_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._map_func.function.captured_inputs,
|
||||
f=self._map_func.function,
|
||||
num_parallel_calls=self._num_parallel_calls,
|
||||
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||
preserve_cardinality=self._preserve_cardinality,
|
||||
**self._flat_structure)
|
||||
super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
def _functions(self):
|
||||
|
@ -95,11 +95,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_with_legacy_function"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -97,11 +97,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_with_legacy_function"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -97,11 +97,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_with_legacy_function"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -97,11 +97,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_with_legacy_function"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -97,11 +97,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_with_legacy_function"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -97,11 +97,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_with_legacy_function"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -97,11 +97,11 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "map_with_legacy_function"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -2644,6 +2644,10 @@ tf_module {
|
||||
name: "ParallelMapDataset"
|
||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParallelMapDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParameterizedTruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'means\', \'stdevs\', \'minvals\', \'maxvals\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
|
@ -66,7 +66,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -68,7 +68,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -67,7 +67,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -68,7 +68,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -68,7 +68,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -68,7 +68,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -68,7 +68,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "map"
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "options"
|
||||
|
@ -2644,6 +2644,10 @@ tf_module {
|
||||
name: "ParallelMapDataset"
|
||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParallelMapDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParameterizedTruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'means\', \'stdevs\', \'minvals\', \'maxvals\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user