Enable op-level dataset determinism configuration for ParseExampleDataset
Users can control determinism at a per-op level by specifying `deterministic`
when calling parse_example_dataset(). The `deterministic` argument takes
precedence over the `experimental_deterministic` dataset option.

PiperOrigin-RevId: 294957589
Change-Id: If974e5fdd403be99a214dd352b4095048afb4701

commit e9a013dd75
parent a7e2b5b416
@@ -0,0 +1,80 @@
+op {
+  graph_op_name: "ParseExampleDatasetV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "dense_defaults"
+    description: <<END
+A dict mapping string keys to `Tensor`s.
+The keys of the dict must match the dense_keys of the feature.
+END
+  }
+  attr {
+    name: "sparse_keys"
+    description: <<END
+A list of string keys in the examples features.
+The results for these keys will be returned as `SparseTensor` objects.
+END
+  }
+  attr {
+    name: "dense_keys"
+    description: <<END
+A list of Ndense string Tensors (scalars).
+The keys expected in the Examples features associated with dense values.
+END
+  }
+  attr {
+    name: "sparse_types"
+    description: <<END
+A list of `DTypes` of the same length as `sparse_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+END
+  }
+  attr {
+    name: "Tdense"
+    description: <<END
+A list of DTypes of the same length as `dense_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+
+END
+  }
+  attr {
+    name: "dense_shapes"
+    description: <<END
+List of tuples with the same length as `dense_keys`.
+The shape of the data for each dense feature referenced by `dense_keys`.
+Required for any input tensors identified by `dense_keys`. Must be
+either fully defined, or may contain an unknown first dimension.
+An unknown first dimension means the feature is treated as having
+a variable number of blocks, and the output shape along this dimension
+is considered unknown at graph build time. Padding is applied for
+minibatch elements smaller than the maximum number of blocks for the
+given feature along this dimension.
+END
+  }
+  attr {
+    name: "output_types"
+    description: <<END
+The type list for the return values.
+END
+  }
+  attr {
+    name: "output_shapes"
+    description: <<END
+The list of shapes being produced.
+END
+  }
+  attr {
+    name: "deterministic"
+    description: <<END
+A string indicating the op-level determinism to use. Deterministic controls
+whether the dataset is allowed to return elements out of order if the next
+element to be returned isn't available, but a later element is. Options are
+"true", "false", and "default". "default" indicates that determinism should be
+decided by the `experimental_deterministic` parameter of `tf.data.Options`.
+END
+  }
+  summary: "Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features."
+}
+
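The `deterministic` attr above is a string rather than a boolean so the Python-level tri-state (`None`/`True`/`False`) can be carried through to the kernel. A small sketch of that mapping, mirroring the wrapper logic added later in this diff (the helper name is illustrative, not part of the change):

    def _determinism_to_attr(deterministic):
      """Illustrative helper: map the tri-state Python argument to the op attr."""
      if deterministic is None:
        return "default"  # defer to tf.data.Options.experimental_deterministic
      return "true" if deterministic else "false"

    assert _determinism_to_attr(None) == "default"
    assert _determinism_to_attr(True) == "true"
    assert _determinism_to_attr(False) == "false"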
@@ -32,10 +32,11 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
     "ParseExampleDataset",
 };

-constexpr std::array<const char*, 3> kDeterministicAttrOps = {
+constexpr std::array<const char*, 4> kDeterministicAttrOps = {
     "ParallelInterleaveDatasetV3",
     "ParallelInterleaveDatasetV4",
     "ParallelMapDatasetV2",
+    "ParseExampleDatasetV2",
 };
 }  // anonymous namespace

@@ -282,6 +282,8 @@ tf_kernel_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:functional_ops_op_lib",
+        "//tensorflow/core/kernels/data:dataset_utils",
+        "//tensorflow/core/kernels/data:name_utils",
         "//tensorflow/core/kernels/data:parallel_map_dataset_op",
         "//tensorflow/core/kernels/data:stats_utils",
     ],
@@ -17,6 +17,8 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/metrics.h"
 #include "tensorflow/core/framework/stats_aggregator.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/name_utils.h"
 #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
 #include "tensorflow/core/kernels/data/stats_utils.h"
 #include "tensorflow/core/util/example_proto_fast_parsing.h"
@@ -30,7 +32,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit ParseExampleDatasetOp(OpKernelConstruction* ctx)
       : UnaryDatasetOpKernel(ctx),
-        graph_def_version_(ctx->graph_def_version()) {
+        graph_def_version_(ctx->graph_def_version()),
+        op_version_(ctx->HasAttr("deterministic") ? 2 : 1) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
@@ -38,7 +41,24 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("sloppy", &sloppy_));
+
+    if (op_version_ == 1) {
+      bool sloppy;
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("sloppy", &sloppy));
+      if (sloppy) {
+        deterministic_ =
+            DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
+      } else {
+        deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
+      }
+    }
+    if (op_version_ == 2) {
+      std::string deterministic;
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("deterministic", &deterministic));
+      OP_REQUIRES_OK(
+          ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
+    }
+
     has_ragged_keys_ = ctx->HasAttr("ragged_keys");
     if (has_ragged_keys_) {
       OP_REQUIRES_OK(ctx, ctx->GetAttr("ragged_keys", &ragged_keys_));
@@ -162,12 +182,12 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
       it->second = i++;
     }

-    *output = new Dataset(ctx, input, dense_defaults, sparse_keys_, dense_keys_,
-                          std::move(key_to_output_index), std::move(config),
-                          num_parallel_calls, sparse_types_, dense_types_,
-                          dense_shapes_, output_types_, output_shapes_, sloppy_,
-                          has_ragged_keys_, ragged_keys_, ragged_value_types_,
-                          ragged_split_types_);
+    *output = new Dataset(
+        ctx, input, dense_defaults, sparse_keys_, dense_keys_,
+        std::move(key_to_output_index), std::move(config), num_parallel_calls,
+        sparse_types_, dense_types_, dense_shapes_, output_types_,
+        output_shapes_, deterministic_, has_ragged_keys_, ragged_keys_,
+        ragged_value_types_, ragged_split_types_, op_version_);
   }

  private:
@@ -182,10 +202,11 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
             const DataTypeVector& dense_types,
             const std::vector<PartialTensorShape>& dense_shapes,
             const DataTypeVector& output_types,
-            const std::vector<PartialTensorShape>& output_shapes, bool sloppy,
-            bool has_ragged_keys, std::vector<string> ragged_keys,
+            const std::vector<PartialTensorShape>& output_shapes,
+            const DeterminismPolicy& deterministic, bool has_ragged_keys,
+            std::vector<string> ragged_keys,
             const DataTypeVector& ragged_value_types,
-            const DataTypeVector& ragged_split_types)
+            const DataTypeVector& ragged_split_types, int op_version)
         : DatasetBase(DatasetContext(ctx)),
           input_(input),
           dense_defaults_(std::move(dense_defaults)),
@@ -202,8 +223,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
           dense_shapes_(dense_shapes),
           output_types_(output_types),
           output_shapes_(output_shapes),
-          sloppy_(sloppy),
-          has_ragged_keys_(has_ragged_keys) {
+          deterministic_(deterministic),
+          has_ragged_keys_(has_ragged_keys),
+          op_version_(op_version) {
       input_->Ref();
     }

@@ -213,9 +235,14 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
         const string& prefix) const override {
       std::unique_ptr<ParallelMapFunctor> parse_example_functor =
           absl::make_unique<ParseExampleFunctor>(this);
+      name_utils::IteratorPrefixParams params;
+      params.op_version = op_version_;
+      bool deterministic =
+          deterministic_.IsDeterministic() || deterministic_.IsDefault();
       return NewParallelMapIterator(
-          {this, strings::StrCat(prefix, "::ParseExample")}, input_,
-          std::move(parse_example_functor), num_parallel_calls_, !sloppy_,
+          {this, name_utils::IteratorPrefix("ParseExample", prefix, params)},
+          input_, std::move(parse_example_functor), num_parallel_calls_,
+          deterministic,
           /*preserve_cardinality=*/true);
     }

@@ -228,7 +255,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
     }

     string DebugString() const override {
-      return "ParseExampleDatasetOp::Dataset";
+      name_utils::DatasetDebugStringParams params;
+      params.op_version = op_version_;
+      return name_utils::DatasetDebugString("ParseExampleDataset", params);
     }

     int64 Cardinality() const override { return input_->Cardinality(); }
@@ -257,60 +286,60 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
        dense_defaults_nodes.emplace_back(node);
      }

-      AttrValue sparse_keys_attr;
-      AttrValue dense_keys_attr;
-      AttrValue sparse_types_attr;
-      AttrValue dense_attr;
-      AttrValue dense_shapes_attr;
-      AttrValue sloppy_attr;
+      std::vector<std::pair<StringPiece, AttrValue>> attrs;

+      AttrValue sparse_keys_attr;
       b->BuildAttrValue(sparse_keys_, &sparse_keys_attr);
+      attrs.emplace_back("sparse_keys", sparse_keys_attr);
+
+      AttrValue dense_keys_attr;
       b->BuildAttrValue(dense_keys_, &dense_keys_attr);
+      attrs.emplace_back("dense_keys", dense_keys_attr);
+
+      AttrValue sparse_types_attr;
       b->BuildAttrValue(sparse_types_, &sparse_types_attr);
+      attrs.emplace_back("sparse_types", sparse_types_attr);
+
+      AttrValue dense_attr;
       b->BuildAttrValue(dense_types_, &dense_attr);
+      attrs.emplace_back("Tdense", dense_attr);
+
+      AttrValue dense_shapes_attr;
       b->BuildAttrValue(dense_shapes_, &dense_shapes_attr);
-      b->BuildAttrValue(sloppy_, &sloppy_attr);
+      attrs.emplace_back("dense_shapes", dense_shapes_attr);
+
+      if (op_version_ == 1) {
+        AttrValue sloppy_attr;
+        b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
+        attrs.emplace_back("sloppy", sloppy_attr);
+      }
+      if (op_version_ == 2) {
+        AttrValue deterministic_attr;
+        b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
+        attrs.emplace_back("deterministic", deterministic_attr);
+      }

       if (has_ragged_keys_) {
         AttrValue ragged_keys_attr;
-        AttrValue ragged_value_types_attr;
-        AttrValue ragged_split_types_attr;
         b->BuildAttrValue(ragged_keys_, &ragged_keys_attr);
-        b->BuildAttrValue(ragged_value_types_, &ragged_value_types_attr);
-        b->BuildAttrValue(ragged_split_types_, &ragged_split_types_attr);
-        TF_RETURN_IF_ERROR(
-            b->AddDataset(this,
-                          {
-                              {0, input_graph_node},
-                              {1, num_parallle_calls_node},
-                          },
-                          {{2, dense_defaults_nodes}},
-                          {{"sparse_keys", sparse_keys_attr},
-                           {"dense_keys", dense_keys_attr},
-                           {"sparse_types", sparse_types_attr},
-                           {"Tdense", dense_attr},
-                           {"dense_shapes", dense_shapes_attr},
-                           {"sloppy", sloppy_attr},
-                           {"ragged_keys", ragged_keys_attr},
-                           {"ragged_value_types", ragged_value_types_attr},
-                           {"ragged_split_types", ragged_split_types_attr}},
-                          output));
-      } else {
+        attrs.emplace_back("ragged_keys", ragged_keys_attr);
+
+        AttrValue ragged_value_types_attr;
+        b->BuildAttrValue(ragged_value_types_, &ragged_value_types_attr);
+        attrs.emplace_back("ragged_value_types", ragged_value_types_attr);
+
+        AttrValue ragged_split_types_attr;
+        b->BuildAttrValue(ragged_split_types_, &ragged_split_types_attr);
+        attrs.emplace_back("ragged_split_types", ragged_split_types_attr);
+      }
+
       TF_RETURN_IF_ERROR(b->AddDataset(this,
                                        {
                                            {0, input_graph_node},
                                            {1, num_parallle_calls_node},
                                        },
-                                       {{2, dense_defaults_nodes}},
-                                       {{"sparse_keys", sparse_keys_attr},
-                                        {"dense_keys", dense_keys_attr},
-                                        {"sparse_types", sparse_types_attr},
-                                        {"Tdense", dense_attr},
-                                        {"dense_shapes", dense_shapes_attr},
-                                        {"sloppy", sloppy_attr}},
+                                       {{2, dense_defaults_nodes}}, attrs,
                                        output));
-      }
       return Status::OK();
     }

@@ -443,14 +472,15 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
     const std::vector<PartialTensorShape> dense_shapes_;
     const DataTypeVector output_types_;
     const std::vector<PartialTensorShape> output_shapes_;
-    const bool sloppy_;
+    const DeterminismPolicy deterministic_;
     const bool has_ragged_keys_;
+    const int op_version_;
   };

   const int graph_def_version_;
   DataTypeVector output_types_;
   std::vector<PartialTensorShape> output_shapes_;
-  bool sloppy_;
+  DeterminismPolicy deterministic_;
   std::vector<string> sparse_keys_;
   std::vector<string> dense_keys_;
   std::vector<string> ragged_keys_;
@@ -462,10 +492,13 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
   std::vector<bool> variable_length_;
   std::vector<std::size_t> elements_per_stride_;
   bool has_ragged_keys_;
+  const int op_version_;
 };

 REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
                         ParseExampleDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ParseExampleDatasetV2").Device(DEVICE_CPU),
+                        ParseExampleDatasetOp);
 REGISTER_KERNEL_BUILDER(
     Name("ExperimentalParseExampleDataset").Device(DEVICE_CPU),
     ParseExampleDatasetOp);
@@ -596,6 +596,27 @@ REGISTER_OP("ParseExampleDataset")
     .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
     .SetShapeFn(shape_inference::ScalarShape);

+REGISTER_OP("ParseExampleDatasetV2")
+    .Input("input_dataset: variant")
+    .Input("num_parallel_calls: int64")
+    .Input("dense_defaults: Tdense")
+    .Output("handle: variant")
+    .Attr("sparse_keys: list(string) >= 0")
+    .Attr("dense_keys: list(string) >= 0")
+    .Attr("sparse_types: list({float,int64,string}) >= 0")
+    .Attr("Tdense: list({float,int64,string}) >= 0")
+    .Attr("dense_shapes: list(shape) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")  // Output components will be
+                                              // sorted by key (dense_keys and
+                                              // sparse_keys combined) here.
+    // "true", "false", or "default".
+    .Attr("deterministic: string = 'default'")
+    .Attr("ragged_keys: list(string) >= 0 = []")
+    .Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
+    .Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("ExperimentalParseExampleDataset")
     .Input("input_dataset: variant")
     .Input("num_parallel_calls: int64")
@@ -476,6 +476,7 @@ tf_py_test(
     name = "parse_example_dataset_test",
     size = "small",
     srcs = ["parse_example_dataset_test.py"],
+    shard_count = 4,
     deps = [
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client_testlib",
@@ -1110,6 +1110,43 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase,
         expected_values=expected_output,
         create_iterator_twice=True)

+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              local_determinism=[None, True, False],
+              global_determinism=[True, False])))
+  def testDeterminism(self, local_determinism, global_determinism):
+    num_elements = 1000
+    batches = []
+    for i in range(num_elements):
+      example_i = example(features=features({
+          "a": int64_feature([i]),
+      }))
+      batches.append([example_i.SerializeToString()])
+
+    test_features = {"a": parsing_ops.FixedLenFeature((), dtype=dtypes.int64)}
+    dataset = dataset_ops.Dataset.from_tensor_slices(batches)
+    dataset = dataset.apply(
+        contrib_parsing_ops.parse_example_dataset(
+            test_features,
+            num_parallel_calls=10,
+            deterministic=local_determinism))
+
+    opts = dataset_ops.Options()
+    opts.experimental_deterministic = global_determinism
+    dataset = dataset.with_options(opts)
+
+    expected = list(range(num_elements))
+    actual = [elem["a"][0] for elem in self.getDatasetOutput(dataset)]
+
+    require_order = local_determinism or (local_determinism is None and
+                                          global_determinism)
+    if require_order:
+      self.assertAllEqual(expected, actual)
+    else:
+      self.assertCountEqual(expected, actual)
+

 if __name__ == "__main__":
   test.main()
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.framework import dtypes
@@ -31,13 +32,20 @@ from tensorflow.python.util.tf_export import tf_export
 class _ParseExampleDataset(dataset_ops.UnaryDataset):
   """A `Dataset` that parses `example` dataset into a `dict` dataset."""

-  def __init__(self, input_dataset, features, num_parallel_calls):
+  def __init__(self, input_dataset, features, num_parallel_calls,
+               deterministic):
     self._input_dataset = input_dataset
     if not structure.are_compatible(
         input_dataset.element_spec,
         tensor_spec.TensorSpec([None], dtypes.string)):
       raise TypeError("Input dataset should be a dataset of vectors of strings")
     self._num_parallel_calls = num_parallel_calls
+    if deterministic is None:
+      self._deterministic = "default"
+    elif deterministic:
+      self._deterministic = "true"
+    else:
+      self._deterministic = "false"
     # pylint: disable=protected-access
     self._features = parsing_ops._prepend_none_dimension(features)
     # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
@@ -77,6 +85,22 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
         self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
             input_dataset_shape.concatenate([None]), value_type, 1, splits_type)

+    if deterministic is not None or compat.forward_compatible(2020, 3, 6):
+      variant_tensor = (
+          gen_experimental_dataset_ops.parse_example_dataset_v2(
+              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+              self._num_parallel_calls,
+              self._dense_defaults,
+              self._sparse_keys,
+              self._dense_keys,
+              self._sparse_types,
+              self._dense_shapes,
+              deterministic=self._deterministic,
+              ragged_keys=self._ragged_keys,
+              ragged_value_types=self._ragged_value_types,
+              ragged_split_types=self._ragged_split_types,
+              **self._flat_structure))
+    else:
     variant_tensor = (
         gen_experimental_dataset_ops.parse_example_dataset(
             self._input_dataset._variant_tensor,  # pylint: disable=protected-access
@@ -99,7 +123,7 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):

 # TODO(b/111553342): add arguments names and example names as well.
 @tf_export("data.experimental.parse_example_dataset")
-def parse_example_dataset(features, num_parallel_calls=1):
+def parse_example_dataset(features, num_parallel_calls=1, deterministic=None):
   """A transformation that parses `Example` protos into a `dict` of tensors.

   Parses a number of serialized `Example` protos given in `serialized`. We refer
@@ -119,6 +143,13 @@ def parse_example_dataset(features, num_parallel_calls=1):
       `VarLenFeature`, `RaggedFeature`, and `SparseFeature` values.
     num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
       representing the number of parsing processes to call in parallel.
+    deterministic: (Optional.) A boolean controlling whether determinism
+      should be traded for performance by allowing elements to be produced out
+      of order if some parsing calls complete faster than others. If
+      `deterministic` is `None`, the
+      `tf.data.Options.experimental_deterministic` dataset option (`True` by
+      default) is used to decide whether to produce elements
+      deterministically.

   Returns:
     A dataset transformation function, which can be passed to
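To illustrate the `None` fallback described in the docstring above, a minimal sketch (the file name and feature spec are placeholders, assuming a build that includes this change): when `deterministic` is left unset, the global option decides.

    import tensorflow as tf

    features = {"value": tf.io.FixedLenFeature([], tf.float32)}  # placeholder spec
    dataset = tf.data.TFRecordDataset("data.tfrecord").batch(16)

    # deterministic is not passed, so it stays None ("default" at the op level)
    # and the decision falls back to the dataset option set below.
    dataset = dataset.apply(
        tf.data.experimental.parse_example_dataset(features, num_parallel_calls=4))

    options = tf.data.Options()
    options.experimental_deterministic = False  # allow out-of-order output
    dataset = dataset.with_options(options)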
@@ -132,7 +163,8 @@ def parse_example_dataset(features, num_parallel_calls=1):

   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-    out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+    out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls,
+                                       deterministic)
     if any(
         isinstance(feature, parsing_ops.SparseFeature) or
         (isinstance(feature, parsing_ops.RaggedFeature) and feature.partitions)
|
@ -190,7 +190,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "parse_example_dataset"
|
name: "parse_example_dataset"
|
||||||
argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
argspec: "args=[\'features\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "prefetch_to_device"
|
name: "prefetch_to_device"
|
||||||
|
@ -2660,6 +2660,10 @@ tf_module {
|
|||||||
name: "ParseExampleDataset"
|
name: "ParseExampleDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'sloppy\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'sloppy\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ParseExampleDatasetV2"
|
||||||
|
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'deterministic\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ParseExampleV2"
|
name: "ParseExampleV2"
|
||||||
argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'ragged_keys\', \'dense_defaults\', \'num_sparse\', \'sparse_types\', \'ragged_value_types\', \'ragged_split_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'ragged_keys\', \'dense_defaults\', \'num_sparse\', \'sparse_types\', \'ragged_value_types\', \'ragged_split_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -158,7 +158,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "parse_example_dataset"
|
name: "parse_example_dataset"
|
||||||
argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
argspec: "args=[\'features\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "prefetch_to_device"
|
name: "prefetch_to_device"
|
||||||
|
@ -2660,6 +2660,10 @@ tf_module {
|
|||||||
name: "ParseExampleDataset"
|
name: "ParseExampleDataset"
|
||||||
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'sloppy\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'sloppy\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ParseExampleDatasetV2"
|
||||||
|
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'deterministic\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ParseExampleV2"
|
name: "ParseExampleV2"
|
||||||
argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'ragged_keys\', \'dense_defaults\', \'num_sparse\', \'sparse_types\', \'ragged_value_types\', \'ragged_split_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'ragged_keys\', \'dense_defaults\', \'num_sparse\', \'sparse_types\', \'ragged_value_types\', \'ragged_split_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|