Enable op-level dataset determinism configuration for ParseExampleDataset
Users can control determinism at a per-op level by specifying `deterministic` when calling parse_example_dataset(). The `deterministic` argument takes higher priority than the `experimental_deterministic` dataset option.

PiperOrigin-RevId: 294957589
Change-Id: If974e5fdd403be99a214dd352b4095048afb4701
This commit is contained in:
parent
a7e2b5b416
commit
e9a013dd75
@ -0,0 +1,80 @@
|
||||
op {
|
||||
graph_op_name: "ParseExampleDatasetV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "dense_defaults"
|
||||
description: <<END
|
||||
A dict mapping string keys to `Tensor`s.
|
||||
The keys of the dict must match the dense_keys of the feature.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "sparse_keys"
|
||||
description: <<END
|
||||
A list of string keys in the examples' features.
|
||||
The results for these keys will be returned as `SparseTensor` objects.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dense_keys"
|
||||
description: <<END
|
||||
A list of Ndense string keys, one per dense feature.
|
||||
The keys expected in the Examples features associated with dense values.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "sparse_types"
|
||||
description: <<END
|
||||
A list of `DTypes` of the same length as `sparse_keys`.
|
||||
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
|
||||
and `tf.string` (`BytesList`) are supported.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "Tdense"
|
||||
description: <<END
|
||||
A list of `DTypes` of the same length as `dense_keys`.
|
||||
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
|
||||
and `tf.string` (`BytesList`) are supported.
|
||||
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dense_shapes"
|
||||
description: <<END
|
||||
List of tuples with the same length as `dense_keys`.
|
||||
The shape of the data for each dense feature referenced by `dense_keys`.
|
||||
Required for any input tensors identified by `dense_keys`. Must be
|
||||
either fully defined, or may contain an unknown first dimension.
|
||||
An unknown first dimension means the feature is treated as having
|
||||
a variable number of blocks, and the output shape along this dimension
|
||||
is considered unknown at graph build time. Padding is applied for
|
||||
minibatch elements smaller than the maximum number of blocks for the
|
||||
given feature along this dimension.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
description: <<END
|
||||
The type list for the return values.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
description: <<END
|
||||
The list of shapes being produced.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "deterministic"
|
||||
description: <<END
|
||||
A string indicating the op-level determinism to use. Deterministic controls
|
||||
whether the dataset is allowed to return elements out of order if the next
|
||||
element to be returned isn't available, but a later element is. Options are
|
||||
"true", "false", and "default". "default" indicates that determinism should be
|
||||
decided by the `experimental_deterministic` parameter of `tf.data.Options`.
|
||||
END
|
||||
}
|
||||
summary: "Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features."
|
||||
}
|
||||
|
@ -32,10 +32,11 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
|
||||
"ParseExampleDataset",
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 3> kDeterministicAttrOps = {
|
||||
constexpr std::array<const char*, 4> kDeterministicAttrOps = {
|
||||
"ParallelInterleaveDatasetV3",
|
||||
"ParallelInterleaveDatasetV4",
|
||||
"ParallelMapDatasetV2",
|
||||
"ParseExampleDatasetV2",
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
|
@ -282,6 +282,8 @@ tf_kernel_library(
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:functional_ops_op_lib",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
"//tensorflow/core/kernels/data:parallel_map_dataset_op",
|
||||
"//tensorflow/core/kernels/data:stats_utils",
|
||||
],
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/metrics.h"
|
||||
#include "tensorflow/core/framework/stats_aggregator.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/stats_utils.h"
|
||||
#include "tensorflow/core/util/example_proto_fast_parsing.h"
|
||||
@ -30,7 +32,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit ParseExampleDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx),
|
||||
graph_def_version_(ctx->graph_def_version()) {
|
||||
graph_def_version_(ctx->graph_def_version()),
|
||||
op_version_(ctx->HasAttr("deterministic") ? 2 : 1) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
|
||||
@ -38,7 +41,24 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("sloppy", &sloppy_));
|
||||
|
||||
if (op_version_ == 1) {
|
||||
bool sloppy;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("sloppy", &sloppy));
|
||||
if (sloppy) {
|
||||
deterministic_ =
|
||||
DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
|
||||
} else {
|
||||
deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
|
||||
}
|
||||
}
|
||||
if (op_version_ == 2) {
|
||||
std::string deterministic;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("deterministic", &deterministic));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
|
||||
}
|
||||
|
||||
has_ragged_keys_ = ctx->HasAttr("ragged_keys");
|
||||
if (has_ragged_keys_) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("ragged_keys", &ragged_keys_));
|
||||
@ -162,12 +182,12 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
it->second = i++;
|
||||
}
|
||||
|
||||
*output = new Dataset(ctx, input, dense_defaults, sparse_keys_, dense_keys_,
|
||||
std::move(key_to_output_index), std::move(config),
|
||||
num_parallel_calls, sparse_types_, dense_types_,
|
||||
dense_shapes_, output_types_, output_shapes_, sloppy_,
|
||||
has_ragged_keys_, ragged_keys_, ragged_value_types_,
|
||||
ragged_split_types_);
|
||||
*output = new Dataset(
|
||||
ctx, input, dense_defaults, sparse_keys_, dense_keys_,
|
||||
std::move(key_to_output_index), std::move(config), num_parallel_calls,
|
||||
sparse_types_, dense_types_, dense_shapes_, output_types_,
|
||||
output_shapes_, deterministic_, has_ragged_keys_, ragged_keys_,
|
||||
ragged_value_types_, ragged_split_types_, op_version_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -182,10 +202,11 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
const DataTypeVector& dense_types,
|
||||
const std::vector<PartialTensorShape>& dense_shapes,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes, bool sloppy,
|
||||
bool has_ragged_keys, std::vector<string> ragged_keys,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
const DeterminismPolicy& deterministic, bool has_ragged_keys,
|
||||
std::vector<string> ragged_keys,
|
||||
const DataTypeVector& ragged_value_types,
|
||||
const DataTypeVector& ragged_split_types)
|
||||
const DataTypeVector& ragged_split_types, int op_version)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
dense_defaults_(std::move(dense_defaults)),
|
||||
@ -202,8 +223,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
dense_shapes_(dense_shapes),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes),
|
||||
sloppy_(sloppy),
|
||||
has_ragged_keys_(has_ragged_keys) {
|
||||
deterministic_(deterministic),
|
||||
has_ragged_keys_(has_ragged_keys),
|
||||
op_version_(op_version) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
@ -213,9 +235,14 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
const string& prefix) const override {
|
||||
std::unique_ptr<ParallelMapFunctor> parse_example_functor =
|
||||
absl::make_unique<ParseExampleFunctor>(this);
|
||||
name_utils::IteratorPrefixParams params;
|
||||
params.op_version = op_version_;
|
||||
bool deterministic =
|
||||
deterministic_.IsDeterministic() || deterministic_.IsDefault();
|
||||
return NewParallelMapIterator(
|
||||
{this, strings::StrCat(prefix, "::ParseExample")}, input_,
|
||||
std::move(parse_example_functor), num_parallel_calls_, !sloppy_,
|
||||
{this, name_utils::IteratorPrefix("ParseExample", prefix, params)},
|
||||
input_, std::move(parse_example_functor), num_parallel_calls_,
|
||||
deterministic,
|
||||
/*preserve_cardinality=*/true);
|
||||
}
|
||||
|
||||
@ -228,7 +255,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "ParseExampleDatasetOp::Dataset";
|
||||
name_utils::DatasetDebugStringParams params;
|
||||
params.op_version = op_version_;
|
||||
return name_utils::DatasetDebugString("ParseExampleDataset", params);
|
||||
}
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
@ -257,60 +286,60 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
dense_defaults_nodes.emplace_back(node);
|
||||
}
|
||||
|
||||
AttrValue sparse_keys_attr;
|
||||
AttrValue dense_keys_attr;
|
||||
AttrValue sparse_types_attr;
|
||||
AttrValue dense_attr;
|
||||
AttrValue dense_shapes_attr;
|
||||
AttrValue sloppy_attr;
|
||||
std::vector<std::pair<StringPiece, AttrValue>> attrs;
|
||||
|
||||
AttrValue sparse_keys_attr;
|
||||
b->BuildAttrValue(sparse_keys_, &sparse_keys_attr);
|
||||
attrs.emplace_back("sparse_keys", sparse_keys_attr);
|
||||
|
||||
AttrValue dense_keys_attr;
|
||||
b->BuildAttrValue(dense_keys_, &dense_keys_attr);
|
||||
attrs.emplace_back("dense_keys", dense_keys_attr);
|
||||
|
||||
AttrValue sparse_types_attr;
|
||||
b->BuildAttrValue(sparse_types_, &sparse_types_attr);
|
||||
attrs.emplace_back("sparse_types", sparse_types_attr);
|
||||
|
||||
AttrValue dense_attr;
|
||||
b->BuildAttrValue(dense_types_, &dense_attr);
|
||||
attrs.emplace_back("Tdense", dense_attr);
|
||||
|
||||
AttrValue dense_shapes_attr;
|
||||
b->BuildAttrValue(dense_shapes_, &dense_shapes_attr);
|
||||
b->BuildAttrValue(sloppy_, &sloppy_attr);
|
||||
attrs.emplace_back("dense_shapes", dense_shapes_attr);
|
||||
|
||||
if (op_version_ == 1) {
|
||||
AttrValue sloppy_attr;
|
||||
b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
|
||||
attrs.emplace_back("sloppy", sloppy_attr);
|
||||
}
|
||||
if (op_version_ == 2) {
|
||||
AttrValue deterministic_attr;
|
||||
b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
|
||||
attrs.emplace_back("deterministic", deterministic_attr);
|
||||
}
|
||||
|
||||
if (has_ragged_keys_) {
|
||||
AttrValue ragged_keys_attr;
|
||||
AttrValue ragged_value_types_attr;
|
||||
AttrValue ragged_split_types_attr;
|
||||
b->BuildAttrValue(ragged_keys_, &ragged_keys_attr);
|
||||
b->BuildAttrValue(ragged_value_types_, &ragged_value_types_attr);
|
||||
b->BuildAttrValue(ragged_split_types_, &ragged_split_types_attr);
|
||||
attrs.emplace_back("ragged_keys", ragged_keys_attr);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this,
|
||||
{
|
||||
{0, input_graph_node},
|
||||
{1, num_parallle_calls_node},
|
||||
},
|
||||
{{2, dense_defaults_nodes}},
|
||||
{{"sparse_keys", sparse_keys_attr},
|
||||
{"dense_keys", dense_keys_attr},
|
||||
{"sparse_types", sparse_types_attr},
|
||||
{"Tdense", dense_attr},
|
||||
{"dense_shapes", dense_shapes_attr},
|
||||
{"sloppy", sloppy_attr},
|
||||
{"ragged_keys", ragged_keys_attr},
|
||||
{"ragged_value_types", ragged_value_types_attr},
|
||||
{"ragged_split_types", ragged_split_types_attr}},
|
||||
output));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(this,
|
||||
{
|
||||
{0, input_graph_node},
|
||||
{1, num_parallle_calls_node},
|
||||
},
|
||||
{{2, dense_defaults_nodes}},
|
||||
{{"sparse_keys", sparse_keys_attr},
|
||||
{"dense_keys", dense_keys_attr},
|
||||
{"sparse_types", sparse_types_attr},
|
||||
{"Tdense", dense_attr},
|
||||
{"dense_shapes", dense_shapes_attr},
|
||||
{"sloppy", sloppy_attr}},
|
||||
output));
|
||||
AttrValue ragged_value_types_attr;
|
||||
b->BuildAttrValue(ragged_value_types_, &ragged_value_types_attr);
|
||||
attrs.emplace_back("ragged_value_types", ragged_value_types_attr);
|
||||
|
||||
AttrValue ragged_split_types_attr;
|
||||
b->BuildAttrValue(ragged_split_types_, &ragged_split_types_attr);
|
||||
attrs.emplace_back("ragged_split_types", ragged_split_types_attr);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(this,
|
||||
{
|
||||
{0, input_graph_node},
|
||||
{1, num_parallle_calls_node},
|
||||
},
|
||||
{{2, dense_defaults_nodes}}, attrs,
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -443,14 +472,15 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
const std::vector<PartialTensorShape> dense_shapes_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
const bool sloppy_;
|
||||
const DeterminismPolicy deterministic_;
|
||||
const bool has_ragged_keys_;
|
||||
const int op_version_;
|
||||
};
|
||||
|
||||
const int graph_def_version_;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
bool sloppy_;
|
||||
DeterminismPolicy deterministic_;
|
||||
std::vector<string> sparse_keys_;
|
||||
std::vector<string> dense_keys_;
|
||||
std::vector<string> ragged_keys_;
|
||||
@ -462,10 +492,13 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<bool> variable_length_;
|
||||
std::vector<std::size_t> elements_per_stride_;
|
||||
bool has_ragged_keys_;
|
||||
const int op_version_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
|
||||
ParseExampleDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ParseExampleDatasetV2").Device(DEVICE_CPU),
|
||||
ParseExampleDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ExperimentalParseExampleDataset").Device(DEVICE_CPU),
|
||||
ParseExampleDatasetOp);
|
||||
|
@ -596,6 +596,27 @@ REGISTER_OP("ParseExampleDataset")
|
||||
.Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ParseExampleDatasetV2")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_parallel_calls: int64")
|
||||
.Input("dense_defaults: Tdense")
|
||||
.Output("handle: variant")
|
||||
.Attr("sparse_keys: list(string) >= 0")
|
||||
.Attr("dense_keys: list(string) >= 0")
|
||||
.Attr("sparse_types: list({float,int64,string}) >= 0")
|
||||
.Attr("Tdense: list({float,int64,string}) >= 0")
|
||||
.Attr("dense_shapes: list(shape) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1") // Output components will be
|
||||
// sorted by key (dense_keys and
|
||||
// sparse_keys combined) here.
|
||||
// `deterministic` must be one of "true", "false", or "default".
|
||||
.Attr("deterministic: string = 'default'")
|
||||
.Attr("ragged_keys: list(string) >= 0 = []")
|
||||
.Attr("ragged_value_types: list({float,int64,string}) >= 0 = []")
|
||||
.Attr("ragged_split_types: list({int32,int64}) >= 0 = []")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalParseExampleDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_parallel_calls: int64")
|
||||
|
@ -476,6 +476,7 @@ tf_py_test(
|
||||
name = "parse_example_dataset_test",
|
||||
size = "small",
|
||||
srcs = ["parse_example_dataset_test.py"],
|
||||
shard_count = 4,
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -1110,6 +1110,43 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase,
|
||||
expected_values=expected_output,
|
||||
create_iterator_twice=True)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
local_determinism=[None, True, False],
|
||||
global_determinism=[True, False])))
|
||||
def testDeterminism(self, local_determinism, global_determinism):
|
||||
num_elements = 1000
|
||||
batches = []
|
||||
for i in range(num_elements):
|
||||
example_i = example(features=features({
|
||||
"a": int64_feature([i]),
|
||||
}))
|
||||
batches.append([example_i.SerializeToString()])
|
||||
|
||||
test_features = {"a": parsing_ops.FixedLenFeature((), dtype=dtypes.int64)}
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(batches)
|
||||
dataset = dataset.apply(
|
||||
contrib_parsing_ops.parse_example_dataset(
|
||||
test_features,
|
||||
num_parallel_calls=10,
|
||||
deterministic=local_determinism))
|
||||
|
||||
opts = dataset_ops.Options()
|
||||
opts.experimental_deterministic = global_determinism
|
||||
dataset = dataset.with_options(opts)
|
||||
|
||||
expected = list(range(num_elements))
|
||||
actual = [elem["a"][0] for elem in self.getDatasetOutput(dataset)]
|
||||
|
||||
require_order = local_determinism or (local_determinism is None and
|
||||
global_determinism)
|
||||
if require_order:
|
||||
self.assertAllEqual(expected, actual)
|
||||
else:
|
||||
self.assertCountEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -31,13 +32,20 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
"""A `Dataset` that parses `example` dataset into a `dict` dataset."""
|
||||
|
||||
def __init__(self, input_dataset, features, num_parallel_calls):
|
||||
def __init__(self, input_dataset, features, num_parallel_calls,
|
||||
deterministic):
|
||||
self._input_dataset = input_dataset
|
||||
if not structure.are_compatible(
|
||||
input_dataset.element_spec,
|
||||
tensor_spec.TensorSpec([None], dtypes.string)):
|
||||
raise TypeError("Input dataset should be a dataset of vectors of strings")
|
||||
self._num_parallel_calls = num_parallel_calls
|
||||
if deterministic is None:
|
||||
self._deterministic = "default"
|
||||
elif deterministic:
|
||||
self._deterministic = "true"
|
||||
else:
|
||||
self._deterministic = "false"
|
||||
# pylint: disable=protected-access
|
||||
self._features = parsing_ops._prepend_none_dimension(features)
|
||||
# TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
|
||||
@ -77,19 +85,35 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
|
||||
input_dataset_shape.concatenate([None]), value_type, 1, splits_type)
|
||||
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.parse_example_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
ragged_keys=self._ragged_keys,
|
||||
ragged_value_types=self._ragged_value_types,
|
||||
ragged_split_types=self._ragged_split_types,
|
||||
**self._flat_structure))
|
||||
if deterministic is not None or compat.forward_compatible(2020, 3, 6):
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.parse_example_dataset_v2(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
deterministic=self._deterministic,
|
||||
ragged_keys=self._ragged_keys,
|
||||
ragged_value_types=self._ragged_value_types,
|
||||
ragged_split_types=self._ragged_split_types,
|
||||
**self._flat_structure))
|
||||
else:
|
||||
variant_tensor = (
|
||||
gen_experimental_dataset_ops.parse_example_dataset(
|
||||
self._input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
self._num_parallel_calls,
|
||||
self._dense_defaults,
|
||||
self._sparse_keys,
|
||||
self._dense_keys,
|
||||
self._sparse_types,
|
||||
self._dense_shapes,
|
||||
ragged_keys=self._ragged_keys,
|
||||
ragged_value_types=self._ragged_value_types,
|
||||
ragged_split_types=self._ragged_split_types,
|
||||
**self._flat_structure))
|
||||
super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
@property
|
||||
@ -99,7 +123,7 @@ class _ParseExampleDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
# TODO(b/111553342): add arguments names and example names as well.
|
||||
@tf_export("data.experimental.parse_example_dataset")
|
||||
def parse_example_dataset(features, num_parallel_calls=1):
|
||||
def parse_example_dataset(features, num_parallel_calls=1, deterministic=None):
|
||||
"""A transformation that parses `Example` protos into a `dict` of tensors.
|
||||
|
||||
Parses a number of serialized `Example` protos given in `serialized`. We refer
|
||||
@ -119,6 +143,13 @@ def parse_example_dataset(features, num_parallel_calls=1):
|
||||
`VarLenFeature`, `RaggedFeature`, and `SparseFeature` values.
|
||||
num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
|
||||
representing the number of parsing processes to call in parallel.
|
||||
deterministic: (Optional.) A boolean controlling whether determinism
|
||||
should be traded for performance by allowing elements to be produced out
|
||||
of order if some parsing calls complete faster than others. If
|
||||
`deterministic` is `None`, the
|
||||
`tf.data.Options.experimental_deterministic` dataset option (`True` by
|
||||
default) is used to decide whether to produce elements
|
||||
deterministically.
|
||||
|
||||
Returns:
|
||||
A dataset transformation function, which can be passed to
|
||||
@ -132,7 +163,8 @@ def parse_example_dataset(features, num_parallel_calls=1):
|
||||
|
||||
def _apply_fn(dataset):
|
||||
"""Function from `Dataset` to `Dataset` that applies the transformation."""
|
||||
out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
|
||||
out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls,
|
||||
deterministic)
|
||||
if any(
|
||||
isinstance(feature, parsing_ops.SparseFeature) or
|
||||
(isinstance(feature, parsing_ops.RaggedFeature) and feature.partitions)
|
||||
|
@ -190,7 +190,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "parse_example_dataset"
|
||||
argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
argspec: "args=[\'features\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "prefetch_to_device"
|
||||
|
@ -2660,6 +2660,10 @@ tf_module {
|
||||
name: "ParseExampleDataset"
|
||||
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'sloppy\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParseExampleDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'deterministic\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParseExampleV2"
|
||||
argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'ragged_keys\', \'dense_defaults\', \'num_sparse\', \'sparse_types\', \'ragged_value_types\', \'ragged_split_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -158,7 +158,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "parse_example_dataset"
|
||||
argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
argspec: "args=[\'features\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "prefetch_to_device"
|
||||
|
@ -2660,6 +2660,10 @@ tf_module {
|
||||
name: "ParseExampleDataset"
|
||||
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'sloppy\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParseExampleDatasetV2"
|
||||
argspec: "args=[\'input_dataset\', \'num_parallel_calls\', \'dense_defaults\', \'sparse_keys\', \'dense_keys\', \'sparse_types\', \'dense_shapes\', \'output_types\', \'output_shapes\', \'deterministic\', \'ragged_keys\', \'ragged_value_types\', \'ragged_split_types\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'[]\', \'[]\', \'[]\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ParseExampleV2"
|
||||
argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'ragged_keys\', \'dense_defaults\', \'num_sparse\', \'sparse_types\', \'ragged_value_types\', \'ragged_split_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
x
Reference in New Issue
Block a user