Updated tf.io.parse_example and tf.io.parse_single_example to support ragged features.

* Added tf.io.RaggedFeature.
* Updated tf.io.parse_example and tf.io.parse_single_example to accept RaggedFeature.
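
A minimal usage sketch of the new API (illustrative; the helper below is
not part of this change):

    import tensorflow as tf

    def make_example(values, lengths):
      # Build a tf.train.Example with a variable-length "values" feature and
      # a "lengths" feature that partitions those values into rows.
      return tf.train.Example(features=tf.train.Features(feature={
          "values": tf.train.Feature(
              float_list=tf.train.FloatList(value=values)),
          "lengths": tf.train.Feature(
              int64_list=tf.train.Int64List(value=lengths)),
      })).SerializeToString()

    serialized = tf.constant([
        make_example([1.0, 2.0, 3.0], [2, 1]),  # rows: [[1, 2], [3]]
        make_example([4.0], [0, 1]),            # rows: [[], [4]]
    ])

    # With one partition key, each Example parses to a 2D ragged value, so
    # the batched parse returns a 3D RaggedTensor (batch, row, value).
    parsed = tf.io.parse_example(serialized, {
        "rt": tf.io.RaggedFeature(
            dtype=tf.float32,
            value_key="values",
            partitions=[tf.io.RaggedFeature.RowLengths("lengths")]),
    })
    # parsed["rt"] == <tf.RaggedTensor [[[1.0, 2.0], [3.0]], [[], [4.0]]]>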

PiperOrigin-RevId: 270024281
Edward Loper, 2019-09-19 06:12:40 -07:00 (committed by TensorFlower Gardener)
parent 9d1bb322c8
commit 86c4003851
36 changed files with 2041 additions and 903 deletions

File: tensorflow/python/BUILD

@@ -3462,6 +3462,15 @@ py_library(
],
)
py_library(
name = "parsing_config",
srcs = ["ops/parsing_config.py"],
srcs_version = "PY2AND3",
deps = [
":framework",
],
)
py_library(
name = "parsing_ops",
srcs = ["ops/parsing_ops.py"],
@@ -3472,6 +3481,7 @@ py_library(
":framework",
":framework_for_generated_wrappers",
":math_ops",
":parsing_config",
":parsing_ops_gen",
":sparse_ops",
],

File: tensorflow/python/kernel_tests/parsing_ops_test.py

@@ -27,16 +27,21 @@ from google.protobuf import json_format
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -58,71 +63,58 @@ def flatten(list_of_lists):
return itertools.chain.from_iterable(list_of_lists)
def flatten_values_tensors_or_sparse(tensors_list):
"""Flatten each SparseTensor object into 3 Tensors for session.run()."""
return list(
flatten([[v.indices, v.values, v.dense_shape]
if isinstance(v, sparse_tensor.SparseTensor) else [v]
for v in tensors_list]))
def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
flat_output):
tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
i = 0 # Index into the flattened output of session.run()
for k, v in dict_tensors.items():
expected_v = expected_tensors[k]
def _compare_output_to_expected(tester, actual, expected):
tester.assertEqual(set(actual.keys()), set(expected.keys()))
for k, v in actual.items():
expected_v = expected[k]
tf_logging.info("Comparing key: %s", k)
if isinstance(v, sparse_tensor.SparseTensor):
# Three outputs for SparseTensor : indices, values, shape.
tester.assertEqual([k, len(expected_v)], [k, 3])
tester.assertAllEqual(expected_v[0], flat_output[i])
tester.assertAllEqual(expected_v[1], flat_output[i + 1])
tester.assertAllEqual(expected_v[2], flat_output[i + 2])
i += 3
tester.assertTrue(isinstance(expected_v, tuple))
tester.assertLen(expected_v, 3)
tester.assertAllEqual(v.indices, expected_v[0])
tester.assertAllEqual(v.values, expected_v[1])
tester.assertAllEqual(v.dense_shape, expected_v[2])
else:
# One output for standard Tensor.
tester.assertAllEqual(expected_v, flat_output[i])
i += 1
tester.assertAllEqual(v, expected_v)
@test_util.run_all_in_graph_and_eager_modes
class ParseExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
with self.cached_session() as sess:
if expected_err:
if expected_err:
if not context.executing_eagerly():
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
out = parsing_ops.parse_example(**kwargs)
sess.run(flatten_values_tensors_or_sparse(out.values()))
return
self.evaluate(parsing_ops.parse_example(**kwargs))
else:
# Returns dict w/ Tensors and SparseTensors.
out = parsing_ops.parse_example(**kwargs)
result = flatten_values_tensors_or_sparse(out.values())
# Check values.
tf_result = self.evaluate(result)
_compare_output_to_expected(self, out, expected_values, tf_result)
with self.assertRaises(Exception):
parsing_ops.parse_example(**kwargs)
return
else:
out = parsing_ops.parse_example(**kwargs)
_compare_output_to_expected(self, out, expected_values)
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
serialized = kwargs["serialized"]
batch_size = (
self.evaluate(serialized).size if isinstance(serialized, ops.Tensor)
else np.asarray(serialized).size)
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
self.assertEqual(
tuple(out[k].get_shape().as_list()), (batch_size,) + f.shape)
elif isinstance(f, parsing_ops.VarLenFeature):
self.assertEqual(
tuple(out[k].indices.get_shape().as_list()), (None, 2))
self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,))
self.assertEqual(
tuple(out[k].dense_shape.get_shape().as_list()), (2,))
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
serialized = kwargs["serialized"]
batch_size = (
self.evaluate(serialized).size
if isinstance(serialized, ops.Tensor) else np.asarray(serialized).size)
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
self.assertEqual(tuple(out[k].shape.as_list()), (batch_size,) + f.shape)
elif isinstance(f, parsing_ops.VarLenFeature):
if context.executing_eagerly():
out[k].indices.shape.assert_is_compatible_with([None, 2])
out[k].values.shape.assert_is_compatible_with([None])
out[k].dense_shape.shape.assert_is_compatible_with([2])
else:
self.assertEqual(out[k].indices.shape.as_list(), [None, 2])
self.assertEqual(out[k].values.shape.as_list(), [None])
self.assertEqual(out[k].dense_shape.shape.as_list(), [2])
@test_util.run_deprecated_v1
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testEmptySerializedWithAllDefaults(self):
sparse_name = "st_a"
a_name = "a"
@@ -162,6 +154,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testEmptySerializedWithoutDefaultsShouldFail(self):
input_features = {
"st_a":
@@ -202,6 +195,7 @@ class ParseExampleTest(test.TestCase):
errors_impl.OpError,
"Name: in1, Feature: c \\(data type: float\\) is required"))
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testDenseNotMatchingShapeShouldFail(self):
original = [
example(features=features({
@@ -226,6 +220,7 @@ class ParseExampleTest(test.TestCase):
expected_err=(errors_impl.OpError,
"Name: failing, Key: a, Index: 1. Number of float val"))
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testDenseDefaultNoShapeShouldFail(self):
original = [
example(features=features({
@@ -245,7 +240,7 @@ class ParseExampleTest(test.TestCase):
},
expected_err=(ValueError, "Missing shape for feature a"))
@test_util.run_deprecated_v1
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingSparse(self):
original = [
example(features=features({
@@ -290,6 +285,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingSparseFeature(self):
original = [
example(
@@ -334,6 +330,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingSparseFeatureReuse(self):
original = [
example(
@@ -377,6 +374,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContaining3DSparseFeature(self):
original = [
example(
@@ -429,6 +427,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingDense(self):
aname = "a"
bname = "b*has+a:tricky_name"
@@ -467,6 +466,7 @@ class ParseExampleTest(test.TestCase):
# This test is identical to the previous one, except
# for the creation of 'serialized'.
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingDenseWithConcat(self):
aname = "a"
bname = "b*has+a:tricky_name"
@@ -514,6 +514,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingDenseScalar(self):
original = [
example(features=features({
@@ -538,6 +539,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingDenseWithDefaults(self):
original = [
example(features=features({
@@ -574,7 +576,7 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.run_deprecated_v1
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
expected_st_a = ( # indices, values, shape
np.empty((0, 2), dtype=np.int64), # indices
@@ -635,7 +637,7 @@ class ParseExampleTest(test.TestCase):
},
expected_output)
@test_util.run_deprecated_v1
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
expected_idx = ( # indices, values, shape
np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
@@ -692,8 +694,9 @@ class ParseExampleTest(test.TestCase):
expected_str = copy.deepcopy(truth_str)
# Delete some intermediate entries
for i in range(batch_size):
# Delete some intermediate entries. (Skip the first entry, to ensure that
# we have at least one entry with length 2, to get the expected padding.)
for i in range(1, batch_size):
col = 1
if np.random.rand() < 0.25:
# w.p. 25%, drop out the second entry
@@ -740,12 +743,13 @@ class ParseExampleTest(test.TestCase):
}
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingVarLenDenseLargerBatch(self):
np.random.seed(3456)
for batch_size in (1, 10, 20, 100, 256):
self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
@test_util.run_deprecated_v1
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingVarLenDense(self):
aname = "a"
bname = "b"
@@ -939,36 +943,267 @@ class ParseExampleTest(test.TestCase):
"Unsupported: FixedLenSequenceFeature requires "
"allow_missing to be True."))
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingRaggedFeatureWithNoPartitions(self):
original = [
example(features=features({"rt_c": float_feature([3, 4])})),
example(
features=features({
"rt_c": float_feature([]), # empty float list
})),
example(
features=features({
"rt_d": feature(), # feature with nothing in it
})),
example(
features=features({
"rt_c": float_feature([1, 2, -1]),
"rt_d": bytes_feature([b"hi"])
}))
]
serialized = [m.SerializeToString() for m in original]
test_features = {
"rt_c":
parsing_ops.RaggedFeature(dtype=dtypes.float32),
"rt_d":
parsing_ops.RaggedFeature(
dtype=dtypes.string, row_splits_dtype=dtypes.int64)
}
expected_rt_c = ragged_factory_ops.constant(
[[3.0, 4.0], [], [], [1.0, 2.0, -1.0]],
dtype=dtypes.float32,
row_splits_dtype=dtypes.int32)
expected_rt_d = ragged_factory_ops.constant([[], [], [], [b"hi"]])
expected_output = {
"rt_c": expected_rt_c,
"rt_d": expected_rt_d,
}
self._test(
{
"serialized": ops.convert_to_tensor(serialized),
"features": test_features
}, expected_output)
# Test with a large enough batch to ensure that the minibatch size is >1.
batch_serialized = serialized * 64
self.assertEqual(expected_rt_c.row_splits.dtype, np.int32)
batch_expected_out = {
"rt_c": ragged_concat_ops.concat([expected_rt_c] * 64, axis=0),
"rt_d": ragged_concat_ops.concat([expected_rt_d] * 64, axis=0)
}
self.assertEqual(batch_expected_out["rt_c"].row_splits.dtype, dtypes.int32)
self._test(
{
"serialized": ops.convert_to_tensor(batch_serialized),
"features": test_features
}, batch_expected_out)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingRaggedFeature(self):
original = [
example(
features=features({
# rt = [[3], [4, 5, 6]]
"rt_values": float_feature([3, 4, 5, 6]),
"rt_splits": int64_feature([0, 1, 4]),
"rt_lengths": int64_feature([1, 3]),
"rt_starts": int64_feature([0, 1]),
"rt_limits": int64_feature([1, 4]),
"rt_rowids": int64_feature([0, 1, 1, 1]),
})),
example(
features=features({
# rt = []
"rt_values": float_feature([]),
"rt_splits": int64_feature([0]),
"rt_lengths": int64_feature([]),
"rt_starts": int64_feature([]),
"rt_limits": int64_feature([]),
"rt_rowids": int64_feature([]),
})),
example(
features=features({
# rt = []
"rt_values": feature(), # feature with nothing in it
"rt_splits": int64_feature([0]),
"rt_lengths": feature(),
"rt_starts": feature(),
"rt_limits": feature(),
"rt_rowids": feature(),
})),
example(
features=features({
# rt = [[1.0, 2.0, -1.0], [], [8.0, 9.0], [5.0]]
"rt_values": float_feature([1, 2, -1, 8, 9, 5]),
"rt_splits": int64_feature([0, 3, 3, 5, 6]),
"rt_lengths": int64_feature([3, 0, 2, 1]),
"rt_starts": int64_feature([0, 3, 3, 5]),
"rt_limits": int64_feature([3, 3, 5, 6]),
"rt_rowids": int64_feature([0, 0, 0, 2, 2, 3]),
}))
]
serialized = ops.convert_to_tensor(
[m.SerializeToString() for m in original])
test_features = {
"rt1":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowSplits("rt_splits")],
dtype=dtypes.float32),
"rt2":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowLengths("rt_lengths")],
dtype=dtypes.float32),
"rt3":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowStarts("rt_starts")],
dtype=dtypes.float32),
"rt4":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowLimits("rt_limits")],
dtype=dtypes.float32),
"rt5":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.ValueRowIds("rt_rowids")],
dtype=dtypes.float32),
"uniform1":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.UniformRowLength(2)],
dtype=dtypes.float32),
"uniform2":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[
parsing_ops.RaggedFeature.UniformRowLength(2),
parsing_ops.RaggedFeature.RowSplits("rt_splits")
],
dtype=dtypes.float32),
}
expected_rt = ragged_factory_ops.constant(
[[[3], [4, 5, 6]], [], [], [[1, 2, -1], [], [8, 9], [5]]],
dtype=dtypes.float32,
row_splits_dtype=dtypes.int32)
expected_uniform1 = ragged_factory_ops.constant(
[[[3, 4], [5, 6]], [], [], [[1, 2], [-1, 8], [9, 5]]],
ragged_rank=1,
dtype=dtypes.float32,
row_splits_dtype=dtypes.int32)
expected_uniform2 = ragged_factory_ops.constant(
[[[[3], [4, 5, 6]]], [], [], [[[1, 2, -1], []], [[8, 9], [5]]]],
dtype=dtypes.float32,
row_splits_dtype=dtypes.int32)
expected_output = {
"rt1": expected_rt,
"rt2": expected_rt,
"rt3": expected_rt,
"rt4": expected_rt,
"rt5": expected_rt,
"uniform1": expected_uniform1,
"uniform2": expected_uniform2,
}
self._test({
"serialized": serialized,
"features": test_features
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSerializedContainingNestedRaggedFeature(self):
"""Test RaggedFeature with 3 partitions."""
original = [
# rt shape: [(batch), 2, None, None]
example(
features=features({
# rt = [[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]]
"rt_values": float_feature([1, 2, 3, 4, 5, 6, 7]),
"lengths_axis2": int64_feature([1, 2, 0, 1]),
"lengths_axis3": int64_feature([1, 2, 1, 3]),
"splits_axis3": int64_feature([0, 1, 3, 4, 7]),
})),
example(
features=features({
# rt = [[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]
"rt_values": float_feature([1, 2, 3, 4, 5, 6, 7, 8]),
"lengths_axis2": int64_feature([2, 3]),
"lengths_axis3": int64_feature([3, 1, 1, 1, 2]),
"splits_axis3": int64_feature([0, 3, 4, 5, 6, 8]),
}))
]
serialized = ops.convert_to_tensor(
[m.SerializeToString() for m in original])
test_features = {
"rt1":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[
parsing_ops.RaggedFeature.UniformRowLength(2),
parsing_ops.RaggedFeature.RowLengths("lengths_axis2"),
parsing_ops.RaggedFeature.RowSplits("splits_axis3"),
],
dtype=dtypes.float32,
row_splits_dtype=dtypes.int64,
),
}
expected_rt = ragged_factory_ops.constant(
[[[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]],
[[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]],
dtype=dtypes.float32,
row_splits_dtype=dtypes.int64)
expected_output = {
"rt1": expected_rt,
}
self._test({
"serialized": serialized,
"features": test_features
}, expected_output)
@test_util.run_all_in_graph_and_eager_modes
class ParseSingleExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
out = parsing_ops.parse_single_example(**kwargs)
sess.run(flatten_values_tensors_or_sparse(out.values()))
else:
# Returns dict w/ Tensors and SparseTensors.
out = parsing_ops.parse_single_example(**kwargs)
# Check values.
tf_result = sess.run(flatten_values_tensors_or_sparse(out.values()))
_compare_output_to_expected(self, out, expected_values, tf_result)
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
self.evaluate(parsing_ops.parse_single_example(**kwargs))
else:
out = parsing_ops.parse_single_example(**kwargs)
_compare_output_to_expected(self, out, expected_values)
# Check shapes.
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
self.assertEqual(
tuple(out[k].get_shape()), tensor_shape.as_shape(f.shape))
elif isinstance(f, parsing_ops.VarLenFeature):
self.assertEqual(
tuple(out[k].indices.get_shape().as_list()), (None, 1))
self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,))
self.assertEqual(
tuple(out[k].dense_shape.get_shape().as_list()), (1,))
# Check shapes.
for k, f in kwargs["features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
self.assertEqual(
tuple(out[k].get_shape()), tensor_shape.as_shape(f.shape))
elif isinstance(f, parsing_ops.VarLenFeature):
if context.executing_eagerly():
self.assertEqual(tuple(out[k].indices.shape.as_list()), (2, 1))
self.assertEqual(tuple(out[k].values.shape.as_list()), (2,))
self.assertEqual(tuple(out[k].dense_shape.shape.as_list()), (1,))
else:
self.assertEqual(tuple(out[k].indices.shape.as_list()), (None, 1))
self.assertEqual(tuple(out[k].values.shape.as_list()), (None,))
self.assertEqual(tuple(out[k].dense_shape.shape.as_list()), (1,))
@test_util.run_deprecated_v1
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSingleExampleWithSparseAndSparseFeatureAndDense(self):
original = example(
features=features({
@@ -981,6 +1216,30 @@ class ParseSingleExampleTest(test.TestCase):
serialized = original.SerializeToString()
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
test_features = {
"st_a":
parsing_ops.VarLenFeature(dtypes.float32),
"sp":
parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
"a":
parsing_ops.FixedLenFeature((1, 3),
dtypes.int64,
default_value=a_default),
"b":
parsing_ops.FixedLenFeature((3, 3),
dtypes.string,
default_value=b_default),
# Feature "c" must be provided, since it has no default_value.
"c":
parsing_ops.FixedLenFeature(2, dtypes.float32),
"d":
parsing_ops.FixedLenSequenceFeature([],
dtypes.float32,
allow_missing=True)
}
expected_st_a = (
np.array([[0], [1]], dtype=np.int64), # indices
np.array([3.0, 4.0], dtype=np.float32), # values
@@ -990,8 +1249,6 @@ class ParseSingleExampleTest(test.TestCase):
np.array([[0], [3]], dtype=np.int64), np.array(["a", "b"], dtype="|S"),
np.array([13], dtype=np.int64)) # max_values = 13
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
expected_output = {
"st_a": expected_st_a,
"sp": expected_sp,
@@ -1005,29 +1262,162 @@ class ParseSingleExampleTest(test.TestCase):
{
"example_names": ops.convert_to_tensor("in1"),
"serialized": ops.convert_to_tensor(serialized),
"features": {
"st_a":
parsing_ops.VarLenFeature(dtypes.float32),
"sp":
parsing_ops.SparseFeature(["idx"], "val", dtypes.string,
[13]),
"a":
parsing_ops.FixedLenFeature(
(1, 3), dtypes.int64, default_value=a_default),
"b":
parsing_ops.FixedLenFeature(
(3, 3), dtypes.string, default_value=b_default),
# Feature "c" must be provided, since it has no default_value.
"c":
parsing_ops.FixedLenFeature(2, dtypes.float32),
"d":
parsing_ops.FixedLenSequenceFeature(
[], dtypes.float32, allow_missing=True)
}
},
expected_output)
"features": test_features,
}, expected_output)
# Note: if example_names is None, then a different code-path gets used.
self._test(
{
"serialized": ops.convert_to_tensor(serialized),
"features": test_features,
}, expected_output)
@test_util.with_forward_compatibility_horizons(None, [2019, 10, 31])
def testSingleExampleWithAllFeatureTypes(self):
original = example(
features=features({
# FixLen features
"c": float_feature([3, 4]),
"d": float_feature([0.0, 1.0]),
# Sparse features
"val": bytes_feature([b"a", b"b"]), # for sp
"idx": int64_feature([0, 3]), # for sp
"st_a": float_feature([3.0, 4.0]),
# Ragged features
"rt_1d": float_feature([3.0, 4.0]),
"rt_values": float_feature([5, 6, 7]), # for rt_2d
"rt_splits": int64_feature([0, 1, 1, 3]), # for rt_2d
"rt_lengths": int64_feature([1, 0, 2]), # for rt_2d
"rt_starts": int64_feature([0, 1, 1]), # for rt_2d
"rt_limits": int64_feature([1, 1, 3]), # for rt_2d
"rt_rowids": int64_feature([0, 2, 2]), # for rt_2d
"rt_splits2": int64_feature([0, 2, 3]), # for rt_3d
}))
serialized = original.SerializeToString()
a_default = [1, 2, 3]
b_default = np.random.rand(3, 3).astype(bytes)
test_features = {
"st_a":
parsing_ops.VarLenFeature(dtypes.float32),
"sp":
parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
"a":
parsing_ops.FixedLenFeature((1, 3),
dtypes.int64,
default_value=a_default),
"b":
parsing_ops.FixedLenFeature((3, 3),
dtypes.string,
default_value=b_default),
# Feature "c" must be provided, since it has no default_value.
"c":
parsing_ops.FixedLenFeature(2, dtypes.float32),
"d":
parsing_ops.FixedLenSequenceFeature([],
dtypes.float32,
allow_missing=True),
"rt_1d":
parsing_ops.RaggedFeature(dtypes.float32),
"rt_2d_with_splits":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowSplits("rt_splits")],
dtype=dtypes.float32),
"rt_2d_with_lengths":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowLengths("rt_lengths")],
dtype=dtypes.float32),
"rt_2d_with_starts":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowStarts("rt_starts")],
dtype=dtypes.float32),
"rt_2d_with_limits":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.RowLimits("rt_limits")],
dtype=dtypes.float32),
"rt_2d_with_rowids":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.ValueRowIds("rt_rowids")],
dtype=dtypes.float32),
"rt_2d_with_uniform_row_length":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[parsing_ops.RaggedFeature.UniformRowLength(1)],
dtype=dtypes.float32),
"rt_3d":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[
parsing_ops.RaggedFeature.RowSplits("rt_splits2"),
parsing_ops.RaggedFeature.RowSplits("rt_splits")
],
dtype=dtypes.float32),
"rt_3d_with_uniform_row_length":
parsing_ops.RaggedFeature(
value_key="rt_values",
partitions=[
parsing_ops.RaggedFeature.UniformRowLength(1),
parsing_ops.RaggedFeature.RowSplits("rt_splits")
],
dtype=dtypes.float32),
}
expected_st_a = (
np.array([[0], [1]], dtype=np.int64), # indices
np.array([3.0, 4.0], dtype=np.float32), # values
np.array([2], dtype=np.int64)) # shape: max_values = 2
expected_sp = ( # indices, values, shape
np.array([[0], [3]], dtype=np.int64), np.array(["a", "b"], dtype="|S"),
np.array([13], dtype=np.int64)) # max_values = 13
expected_rt_1d = constant_op.constant([3, 4], dtypes.float32)
expected_rt_2d = ragged_factory_ops.constant([[5], [], [6, 7]],
dtype=dtypes.float32)
expected_rt_2d_uniform = constant_op.constant([[5], [6], [7]],
dtype=dtypes.float32)
expected_rt_3d = ragged_factory_ops.constant([[[5], []], [[6, 7]]],
dtype=dtypes.float32)
expected_rt_3d_with_uniform = (
ragged_tensor.RaggedTensor.from_uniform_row_length(
expected_rt_2d, uniform_row_length=1))
expected_output = {
"st_a": expected_st_a,
"sp": expected_sp,
"a": [a_default],
"b": b_default,
"c": np.array([3, 4], dtype=np.float32),
"d": np.array([0.0, 1.0], dtype=np.float32),
"rt_1d": expected_rt_1d,
"rt_2d_with_splits": expected_rt_2d,
"rt_2d_with_lengths": expected_rt_2d,
"rt_2d_with_starts": expected_rt_2d,
"rt_2d_with_limits": expected_rt_2d,
"rt_2d_with_rowids": expected_rt_2d,
"rt_2d_with_uniform_row_length": expected_rt_2d_uniform,
"rt_3d": expected_rt_3d,
"rt_3d_with_uniform_row_length": expected_rt_3d_with_uniform,
}
self._test(
{
"example_names": ops.convert_to_tensor("in1"),
"serialized": ops.convert_to_tensor(serialized),
"features": test_features,
}, expected_output)
@test_util.run_all_in_graph_and_eager_modes
class ParseSequenceExampleTest(test.TestCase):
def testCreateSequenceExample(self):
@@ -1061,70 +1451,55 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values = expected_feat_list_values or {}
expected_length_values = expected_length_values or {}
with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
if batch:
c_out, fl_out, _ = parsing_ops.parse_sequence_example(**kwargs)
else:
c_out, fl_out = parsing_ops.parse_single_sequence_example(**kwargs)
if c_out:
sess.run(flatten_values_tensors_or_sparse(c_out.values()))
if fl_out:
sess.run(flatten_values_tensors_or_sparse(fl_out.values()))
else:
# Returns dicts w/ Tensors and SparseTensors.
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
if batch:
(context_out, feat_list_out,
lengths_out) = parsing_ops.parse_sequence_example(**kwargs)
self.evaluate(parsing_ops.parse_sequence_example(**kwargs))
else:
(context_out,
feat_list_out) = parsing_ops.parse_single_sequence_example(**kwargs)
lengths_out = {}
self.evaluate(parsing_ops.parse_single_sequence_example(**kwargs))
else:
if batch:
(context_out, feat_list_out,
lengths_out) = parsing_ops.parse_sequence_example(**kwargs)
else:
(context_out,
feat_list_out) = parsing_ops.parse_single_sequence_example(**kwargs)
lengths_out = {}
context_result = sess.run(
flatten_values_tensors_or_sparse(
context_out.values())) if context_out else []
feat_list_result = sess.run(
flatten_values_tensors_or_sparse(
feat_list_out.values())) if feat_list_out else []
lengths_result = sess.run(
flatten_values_tensors_or_sparse(
lengths_out.values())) if lengths_out else []
# Check values.
_compare_output_to_expected(self, context_out, expected_context_values,
context_result)
_compare_output_to_expected(self, feat_list_out,
expected_feat_list_values, feat_list_result)
_compare_output_to_expected(self, lengths_out, expected_length_values,
lengths_result)
# Check values.
_compare_output_to_expected(self, context_out, expected_context_values)
_compare_output_to_expected(self, feat_list_out,
expected_feat_list_values)
_compare_output_to_expected(self, lengths_out, expected_length_values)
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
if "context_features" in kwargs:
for k, f in kwargs["context_features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
if batch:
self.assertEqual(
tuple(context_out[k].get_shape().as_list()[1:]), f.shape)
else:
self.assertEqual(
tuple(context_out[k].get_shape().as_list()), f.shape)
elif isinstance(f, parsing_ops.VarLenFeature) and batch:
self.assertEqual(
tuple(context_out[k].indices.get_shape().as_list()), (None, 2))
self.assertEqual(
tuple(context_out[k].values.get_shape().as_list()), (None,))
self.assertEqual(
tuple(context_out[k].dense_shape.get_shape().as_list()), (2,))
elif isinstance(f, parsing_ops.VarLenFeature) and not batch:
self.assertEqual(
tuple(context_out[k].indices.get_shape().as_list()), (None, 1))
self.assertEqual(
tuple(context_out[k].values.get_shape().as_list()), (None,))
self.assertEqual(
tuple(context_out[k].dense_shape.get_shape().as_list()), (1,))
# Check shapes; if serialized is a Tensor we need its size to
# properly check.
if "context_features" in kwargs:
for k, f in kwargs["context_features"].items():
if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
if batch:
self.assertEqual(tuple(context_out[k].shape.as_list()[1:]), f.shape)
else:
self.assertEqual(tuple(context_out[k].shape.as_list()), f.shape)
elif isinstance(f, parsing_ops.VarLenFeature) and batch:
if context.executing_eagerly():
context_out[k].indices.shape.assert_is_compatible_with([None, 2])
context_out[k].values.shape.assert_is_compatible_with([None])
context_out[k].dense_shape.shape.assert_is_compatible_with([2])
else:
self.assertEqual(context_out[k].indices.shape.as_list(), [None, 2])
self.assertEqual(context_out[k].values.shape.as_list(), [None])
self.assertEqual(context_out[k].dense_shape.shape.as_list(), [2])
elif isinstance(f, parsing_ops.VarLenFeature) and not batch:
if context.executing_eagerly():
context_out[k].indices.shape.assert_is_compatible_with([None, 1])
context_out[k].values.shape.assert_is_compatible_with([None])
context_out[k].dense_shape.shape.assert_is_compatible_with([1])
else:
self.assertEqual(context_out[k].indices.shape.as_list(), [None, 1])
self.assertEqual(context_out[k].values.shape.as_list(), [None])
self.assertEqual(context_out[k].dense_shape.shape.as_list(), [1])
def _testBoth(self,
kwargs,
@@ -1187,7 +1562,6 @@ class ParseSequenceExampleTest(test.TestCase):
expected_err=expected_err,
batch=True)
@test_util.run_deprecated_v1
def testSequenceExampleWithSparseAndDenseContext(self):
original = sequence_example(
context=features({
@@ -1231,7 +1605,6 @@ class ParseSequenceExampleTest(test.TestCase):
},
expected_context_values=expected_context_output)
@test_util.run_deprecated_v1
def testSequenceExampleWithMultipleSizeFeatureLists(self):
original = sequence_example(
feature_lists=feature_lists({
@@ -1295,7 +1668,6 @@ class ParseSequenceExampleTest(test.TestCase):
},
expected_feat_list_values=expected_feature_list_output)
@test_util.run_deprecated_v1
def testSequenceExampleWithoutDebugName(self):
original = sequence_example(
feature_lists=feature_lists({
@@ -1353,7 +1725,6 @@ class ParseSequenceExampleTest(test.TestCase):
},
expected_feat_list_values=expected_feature_list_output)
@test_util.run_deprecated_v1
def testSequenceExampleWithSparseAndDenseFeatureLists(self):
original = sequence_example(
feature_lists=feature_lists({
@@ -1412,7 +1783,6 @@ class ParseSequenceExampleTest(test.TestCase):
},
expected_feat_list_values=expected_feature_list_output)
@test_util.run_deprecated_v1
def testSequenceExampleWithEmptyFeatureInFeatureLists(self):
original = sequence_example(
feature_lists=feature_lists({
@@ -1553,7 +1923,6 @@ class ParseSequenceExampleTest(test.TestCase):
" feature_list_dense_missing_assumed_empty or"
" feature_list_dense_defaults?"))
@test_util.run_deprecated_v1
def testSequenceExampleBatch(self):
first = sequence_example(
feature_lists=feature_lists({
@@ -1616,6 +1985,7 @@ class ParseSequenceExampleTest(test.TestCase):
batch=True)
@test_util.run_all_in_graph_and_eager_modes
class DecodeRawTest(test.TestCase):
def _decode_v1(self, words):
@@ -1686,30 +2056,30 @@ class DecodeRawTest(test.TestCase):
self.assertAllEqual(expected, observed)
@test_util.run_all_in_graph_and_eager_modes
class DecodeJSONExampleTest(test.TestCase):
def _testRoundTrip(self, examples):
with self.cached_session() as sess:
examples = np.array(examples, dtype=np.object)
examples = np.array(examples, dtype=np.object)
json_tensor = constant_op.constant(
[json_format.MessageToJson(m) for m in examples.flatten()],
shape=examples.shape,
dtype=dtypes.string)
binary_tensor = parsing_ops.decode_json_example(json_tensor)
binary_val = self.evaluate(binary_tensor)
json_tensor = constant_op.constant(
[json_format.MessageToJson(m) for m in examples.flatten()],
shape=examples.shape,
dtype=dtypes.string)
binary_tensor = parsing_ops.decode_json_example(json_tensor)
binary_val = self.evaluate(binary_tensor)
if examples.shape:
self.assertShapeEqual(binary_val, json_tensor)
for input_example, output_binary in zip(
np.array(examples).flatten(), binary_val.flatten()):
output_example = example_pb2.Example()
output_example.ParseFromString(output_binary)
self.assertProtoEquals(input_example, output_example)
else:
if examples.shape:
self.assertShapeEqual(binary_val, json_tensor)
for input_example, output_binary in zip(
np.array(examples).flatten(), binary_val.flatten()):
output_example = example_pb2.Example()
output_example.ParseFromString(binary_val)
self.assertProtoEquals(examples.item(), output_example)
output_example.ParseFromString(output_binary)
self.assertProtoEquals(input_example, output_example)
else:
output_example = example_pb2.Example()
output_example.ParseFromString(binary_val)
self.assertProtoEquals(examples.item(), output_example)
def testEmptyTensor(self):
self._testRoundTrip([])
@@ -1778,10 +2148,13 @@ class DecodeJSONExampleTest(test.TestCase):
})),
])
@test_util.run_deprecated_v1
def testInvalidSyntax(self):
with self.cached_session() as sess:
json_tensor = constant_op.constant(["{]"])
json_tensor = constant_op.constant(["{]"])
if context.executing_eagerly():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Error while parsing JSON"):
parsing_ops.decode_json_example(json_tensor)
else:
binary_tensor = parsing_ops.decode_json_example(json_tensor)
with self.assertRaisesOpError("Error while parsing JSON"):
self.evaluate(binary_tensor)

File: tensorflow/python/ops/parallel_for/pfor.py

@@ -3258,6 +3258,38 @@ def _convert_parse_single_example(pfor_input):
return [wrap(t, True, True) for t in nest.flatten(output)]
@RegisterPFor("ParseExampleV2")
def _convert_parse_example_v2(pfor_input):
serialized = pfor_input.stacked_input(0)
sparse_keys = pfor_input.unstacked_input(2)
dense_keys = pfor_input.unstacked_input(3)
ragged_keys = pfor_input.unstacked_input(4)
dense_defaults = [
pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs)
]
num_sparse = pfor_input.get_attr("num_sparse")
sparse_types = pfor_input.get_attr("sparse_types")
ragged_value_types = pfor_input.get_attr("ragged_value_types")
ragged_split_types = pfor_input.get_attr("ragged_split_types")
dense_shapes = pfor_input.get_attr("dense_shapes")
if serialized.shape.ndims not in (None, 1):
raise ValueError("ParseExampleV2 can only be converted if `serialized` "
"is scalar.")
output = gen_parsing_ops.parse_example_v2(
serialized=serialized,
names=[],
sparse_keys=sparse_keys,
dense_keys=dense_keys,
ragged_keys=ragged_keys,
dense_defaults=dense_defaults,
num_sparse=num_sparse,
sparse_types=sparse_types,
ragged_value_types=ragged_value_types,
ragged_split_types=ragged_split_types,
dense_shapes=dense_shapes)
return [wrap(t, True, True) for t in nest.flatten(output)]
# functional_ops

File: tensorflow/python/ops/parsing_config.py (new file)

@@ -0,0 +1,861 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Feature configuration for tf.io.parse_example."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# Avoid circular dependencies with RaggedTensor.
# TODO(b/141170488) Refactor ragged modules so this is unnecessary.
ragged_tensor = LazyLoader(
"ragged_tensor", globals(),
"tensorflow.python.ops.ragged.ragged_tensor")
ragged_math_ops = LazyLoader(
"ragged_math_ops", globals(),
"tensorflow.python.ops.ragged.ragged_math_ops")
# TODO(b/122887740) Refactor code:
# * Move input verification to feature configuration objects (e.g.,
# VarLenFeature should check that dtype is a valid dtype).
# * Add an _add_feature() method to each feature configuration object
# (rather than using a dispatch table in _ParseOpParams._add_feature).
# * Update _construct_tensors_for_composite_features() to call a method
# on the feature object (rather than using dispatch).
@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"])
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
"""Configuration for parsing a variable-length input feature.
Fields:
dtype: Data type of input.
"""
pass
@tf_export("io.RaggedFeature")
class RaggedFeature(
collections.namedtuple(
"RaggedFeature",
["dtype", "value_key", "partitions", "row_splits_dtype", "validate"])):
"""Configuration for passing a RaggedTensor input feature.
`value_key` specifies the feature key for a variable-length list of values;
and `partitions` specifies zero or more feature keys for partitioning those
values into higher dimensions. Each element of `partitions` must be one of
the following:
* `tf.io.RaggedFeature.RowSplits(key: string)`
* `tf.io.RaggedFeature.RowLengths(key: string)`
* `tf.io.RaggedFeature.RowStarts(key: string)`
* `tf.io.RaggedFeature.RowLimits(key: string)`
* `tf.io.RaggedFeature.ValueRowIds(key: string)`
* `tf.io.RaggedFeature.UniformRowLength(length: int)`.
where `key` is a feature key whose values are used to partition the
`value_key` values.
Partitions are listed from outermost to innermost.
* If `len(partitions) == 0` (the default), then:
* A feature from a single `tf.Example` is parsed into a 1D `tf.Tensor`.
* A feature from a batch of `tf.Example`s is parsed into a 2D
`tf.RaggedTensor`, where the outer dimension is the batch dimension, and
the inner (ragged) dimension is the feature length in each example.
* If `len(partitions) == 1`, then:
* A feature from a single `tf.Example` is parsed into a 2D
`tf.RaggedTensor`, where the values taken from the `value_key` are
separated into rows using the partition key.
* A feature from a batch of `tf.Example`s is parsed into a 3D
`tf.RaggedTensor`, where the outer dimension is the batch dimension,
the two inner dimensions are formed by separating the `value_key` values
from each example into rows using that example's partition key.
* If `len(partitions) > 1`, then:
* A feature from a single `tf.Example` is parsed into a `tf.RaggedTensor`
whose rank is `len(partitions)+1`, and whose ragged_rank is
`len(partitions)`.
* A feature from a batch of `tf.Example`s is parsed into a `tf.RaggedTensor`
whose rank is `len(partitions)+2` and whose ragged_rank is
`len(partitions)+1`, where the outer dimension is the batch dimension.
There is one exception: if the final (i.e., innermost) element(s) of
`partitions` are `UniformRowLength`s, then the values are simply reshaped (as
a higher-dimensional `tf.Tensor`), rather than being wrapped in a
`tf.RaggedTensor`; the final doctest below illustrates this.
#### Examples
>>> import google.protobuf.text_format as pbtext
>>> example_batch = [
... pbtext.Merge(r'''
... features {
... feature {key: "v" value {int64_list {value: [3, 1, 4, 1, 5, 9]}}}
... feature {key: "s1" value {int64_list {value: [0, 2, 3, 3, 6]}}}
... feature {key: "s2" value {int64_list {value: [0, 2, 3, 4]}}}
... }''', tf.train.Example()).SerializeToString(),
... pbtext.Merge(r'''
... features {
... feature {key: "v" value {int64_list {value: [2, 7, 1, 8, 2, 8, 1]}}}
... feature {key: "s1" value {int64_list {value: [0, 3, 4, 5, 7]}}}
... feature {key: "s2" value {int64_list {value: [0, 1, 1, 4]}}}
... }''', tf.train.Example()).SerializeToString()]
>>> features = {
... # Zero partitions: returns 1D tf.Tensor for each Example.
... 'f1': tf.io.RaggedFeature(value_key="v", dtype=tf.int64),
... # One partition: returns 2D tf.RaggedTensor for each Example.
... 'f2': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
... tf.io.RaggedFeature.RowSplits("s1")]),
... # Two partitions: returns 3D tf.RaggedTensor for each Example.
... 'f3': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
... tf.io.RaggedFeature.RowSplits("s2"),
... tf.io.RaggedFeature.RowSplits("s1")])
... }
>>> tf.io.parse_single_example(example_batch[0], features)
{'f1': <tf.Tensor: ..., numpy=array([3, 1, 4, 1, 5, 9])>,
'f2': <tf.RaggedTensor [[3, 1], [4], [], [1, 5, 9]]>,
'f3': <tf.RaggedTensor [[[3, 1], [4]], [[]], [[1, 5, 9]]]>}
>>> tf.io.parse_example(example_batch, features)
{'f1': <tf.RaggedTensor [[3, 1, 4, 1, 5, 9], [2, 7, 1, 8, 2, 8, 1]]>,
'f2': <tf.RaggedTensor [[[3, 1], [4], [], [1, 5, 9]],
[[2, 7, 1], [8], [2], [8, 1]]]>,
'f3': <tf.RaggedTensor [[[[3, 1], [4]], [[]], [[1, 5, 9]]],
[[[2, 7, 1]], [], [[8], [2], [8, 1]]]]>}
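If the innermost partition is a `UniformRowLength`, the values are instead
reshaped into a dense `tf.Tensor` (an illustrative doctest, reusing the
protos defined above):

>>> tf.io.parse_single_example(example_batch[0], {
...     'f4': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
...         tf.io.RaggedFeature.UniformRowLength(2)])})['f4']
<tf.Tensor: ..., numpy=array([[3, 1], [4, 1], [5, 9]])>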
Fields:
dtype: Data type of the `RaggedTensor`. Must be one of:
`tf.dtypes.int64`, `tf.dtypes.float32`, `tf.dtypes.string`.
value_key: (Optional.) Key for a `Feature` in the input `Example`, whose
parsed `Tensor` will be the resulting `RaggedTensor.flat_values`. If
not specified, then it defaults to the key for this `RaggedFeature`.
partitions: (Optional.) A list of objects specifying the row-partitioning
tensors (from outermost to innermost). Each entry in this list must be
one of:
* `tf.io.RaggedFeature.RowSplits(key: string)`
* `tf.io.RaggedFeature.RowLengths(key: string)`
* `tf.io.RaggedFeature.RowStarts(key: string)`
* `tf.io.RaggedFeature.RowLimits(key: string)`
* `tf.io.RaggedFeature.ValueRowIds(key: string)`
* `tf.io.RaggedFeature.UniformRowLength(length: int)`.
Where `key` is a key for a `Feature` in the input `Example`, whose parsed
`Tensor` will be the resulting row-partitioning tensor.
row_splits_dtype: (Optional.) Data type for the row-partitioning tensor(s).
One of `int32` or `int64`. Defaults to `int32`.
validate: (Optional.) Boolean indicating whether or not to validate that
the input values form a valid RaggedTensor. Defaults to `False`.
"""
# pylint: disable=invalid-name
RowSplits = collections.namedtuple("RowSplits", ["key"])
RowLengths = collections.namedtuple("RowLengths", ["key"])
RowStarts = collections.namedtuple("RowStarts", ["key"])
RowLimits = collections.namedtuple("RowLimits", ["key"])
ValueRowIds = collections.namedtuple("ValueRowIds", ["key"])
UniformRowLength = collections.namedtuple("UniformRowLength", ["length"])
# pylint: enable=invalid-name
_PARTITION_TYPES = (RowSplits, RowLengths, RowStarts, RowLimits, ValueRowIds,
UniformRowLength)
def __new__(cls,
dtype,
value_key=None,
partitions=(),
row_splits_dtype=dtypes.int32,
validate=False):
if value_key is not None:
if not isinstance(value_key, str):
raise ValueError("value_key must be a string; got %r" % value_key)
if not value_key:
raise ValueError("value_key may not be empty")
dtype = dtypes.as_dtype(dtype)
if dtype not in (dtypes.int64, dtypes.float32, dtypes.string):
raise ValueError("dtypes must be int64, float32, or bytes; got %r" %
dtype)
row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
if row_splits_dtype not in (dtypes.int32, dtypes.int64):
raise ValueError("row_splits_dtype must be int32 or int64; got %r" %
row_splits_dtype)
if not isinstance(partitions, (list, tuple)):
raise TypeError("partitions must be a list or tuple")
for partition in partitions:
if not isinstance(partition, cls._PARTITION_TYPES):
raise TypeError("partitions must be a list of partition objects %s;"
" got: %r" % (cls._PARTITION_TYPES, partition))
if not isinstance(validate, bool):
raise TypeError("validate must be a bool; got %r" % validate)
return super(RaggedFeature, cls).__new__(cls, dtype, value_key, partitions,
row_splits_dtype, validate)
@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"])
class SparseFeature(
collections.namedtuple(
"SparseFeature",
["index_key", "value_key", "dtype", "size", "already_sorted"])):
"""Configuration for parsing a sparse input feature from an `Example`.
Note: prefer `VarLenFeature` (possibly in combination with a
`SequenceExample`) over `SparseFeature`, since `VarLenFeature` is simpler
to use.
Closely mimicking the `SparseTensor` that will be obtained by parsing an
`Example` with a `SparseFeature` config, a `SparseFeature` contains a
* `value_key`: The name of the key for a `Feature` in the `Example` whose
parsed `Tensor` will be the resulting `SparseTensor.values`.
* `index_key`: A list of names, one for each dimension in the resulting
`SparseTensor`: `indices[i][dim]`, the position of the `i`-th value in
dimension `dim`, is given by the `i`-th value of the feature named
`index_key[dim]` in the `Example`.
* `size`: A list of ints for the resulting `SparseTensor.dense_shape`.
For example, we can represent the following 2D `SparseTensor`
```python
SparseTensor(indices=[[3, 1], [20, 0]],
values=[0.5, -1.0]
dense_shape=[100, 3])
```
with an `Example` input proto
```python
features {
feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } }
feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } }
}
```
and `SparseFeature` config with 2 `index_key`s
```python
SparseFeature(index_key=["ix0", "ix1"],
value_key="val",
dtype=tf.float32,
size=[100, 3])
```
Fields:
index_key: A single string name or a list of string names of index features.
For each key the underlying feature's type must be `int64` and its length
must always match that of the `value_key` feature.
To represent `SparseTensor`s with a `dense_shape` of `rank` higher than 1
a list of length `rank` should be used.
value_key: Name of value feature. The underlying feature's type must
be `dtype` and its length must always match that of all the `index_key`s'
features.
dtype: Data type of the `value_key` feature.
size: A Python int or list thereof specifying the dense shape. Should be a
list if and only if `index_key` is a list, in which case it must have the
same length as `index_key`. For each entry `i`, all values in the
`index_key[i]` feature must be in `[0, size[i])`.
already_sorted: A Python boolean to specify whether the values in
`value_key` are already sorted by their index position. If so skip
sorting. False by default (optional).
"""
def __new__(cls, index_key, value_key, dtype, size, already_sorted=False):
return super(SparseFeature, cls).__new__(
cls, index_key, value_key, dtype, size, already_sorted)
@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"])
class FixedLenFeature(collections.namedtuple(
"FixedLenFeature", ["shape", "dtype", "default_value"])):
"""Configuration for parsing a fixed-length input feature.
To treat sparse input as dense, provide a `default_value`; otherwise,
the parse functions will fail on any examples missing this feature.
Fields:
shape: Shape of input data.
dtype: Data type of input.
default_value: Value to be used if an example is missing this feature. It
must be compatible with `dtype` and of the specified `shape`.
"""
def __new__(cls, shape, dtype, default_value=None):
return super(FixedLenFeature, cls).__new__(
cls, shape, dtype, default_value)
@tf_export("io.FixedLenSequenceFeature",
v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"])
class FixedLenSequenceFeature(collections.namedtuple(
"FixedLenSequenceFeature",
["shape", "dtype", "allow_missing", "default_value"])):
"""Configuration for parsing a variable-length input feature into a `Tensor`.
The resulting `Tensor` of parsing a single `SequenceExample` or `Example` has
a static `shape` of `[None] + shape` and the specified `dtype`.
The resulting `Tensor` of parsing a `batch_size` many `Example`s has
a static `shape` of `[batch_size, None] + shape` and the specified `dtype`.
The entries in the `batch` from different `Examples` will be padded with
`default_value` to the maximum length present in the `batch`.
To treat a sparse input as dense, provide `allow_missing=True`; otherwise,
the parse functions will fail on any examples missing this feature.
Fields:
shape: Shape of input data for dimension 2 and higher. First dimension is
of variable length `None`.
dtype: Data type of input.
allow_missing: Whether to allow this feature to be missing from a feature
list item. Available only when parsing `SequenceExample`s, not when
parsing `Example`s.
default_value: Scalar value to be used to pad multiple `Example`s to their
maximum length. Irrelevant for parsing a single `Example` or
`SequenceExample`. Defaults to "" for dtype string and 0 otherwise
(optional).
"""
def __new__(cls, shape, dtype, allow_missing=False, default_value=None):
return super(FixedLenSequenceFeature, cls).__new__(
cls, shape, dtype, allow_missing, default_value)
class _ParseOpParams(object):
"""Raw parameters used by `gen_parsing_ops`.
Attributes:
sparse_keys: A list of string keys in the examples' features. The results
for these keys will be returned as `SparseTensor` objects.
sparse_types: A list of `DTypes` of the same length as `sparse_keys`. Only
`tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
(`BytesList`) are supported.
dense_keys: A list of string keys in the examples' features. The results for
these keys will be returned as `Tensor`s
dense_types: A list of DTypes of the same length as `dense_keys`. Only
`tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
(`BytesList`) are supported.
dense_defaults: A dict mapping string keys to `Tensor`s. The keys of the
dict must match the dense_keys of the feature.
dense_shapes: A list of tuples with the same length as `dense_keys`. The
shape of the data for each dense feature referenced by `dense_keys`.
Required for any input tensors identified by `dense_keys`. Must be either
fully defined, or may contain an unknown first dimension. An unknown first
dimension means the feature is treated as having a variable number of
blocks, and the output shape along this dimension is considered unknown at
graph build time. Padding is applied for minibatch elements smaller than
the maximum number of blocks for the given feature along this dimension.
ragged_keys: A list of string keys in the examples' features. The
results for these keys will be returned as `RaggedTensor` objects.
ragged_value_types: A list of `DTypes` of the same length as `ragged_keys`,
specifying the value type for each ragged feature. Must be one of:
`tf.float32`, `tf.int64`, `tf.string`.
ragged_split_types: A list of `DTypes` of the same length as `ragged_keys`,
specifying the row_splits type for each ragged feature. Must be one of:
`tf.int32`, `tf.int64`.
dense_shapes_as_proto: dense_shapes converted to TensorShapeProto.
dense_defaults_vec: A vector of `Tensor`s containing the default values,
corresponding 1:1 with `dense_keys`.
num_features: The total number of feature keys.
"""
def __init__(self,
sparse_keys=None,
sparse_types=None,
dense_keys=None,
dense_types=None,
dense_defaults=None,
dense_shapes=None,
ragged_keys=None,
ragged_value_types=None,
ragged_split_types=None):
# Note: we use an OrderedDict for dense_defaults, to ensure consistent
# graph construction order for _e2e_test.
dense_defaults = (
collections.OrderedDict() if dense_defaults is None else dense_defaults)
sparse_keys = [] if sparse_keys is None else sparse_keys
sparse_types = [] if sparse_types is None else sparse_types
dense_keys = [] if dense_keys is None else dense_keys
dense_types = [] if dense_types is None else dense_types
dense_shapes = ([[]] *
len(dense_keys) if dense_shapes is None else dense_shapes)
ragged_keys = [] if ragged_keys is None else ragged_keys
ragged_value_types = ([]
if ragged_value_types is None else ragged_value_types)
ragged_split_types = ([]
if ragged_split_types is None else ragged_split_types)
self.sparse_keys = sparse_keys
self.sparse_types = [dtypes.as_dtype(t) for t in sparse_types]
self.dense_keys = dense_keys
self.dense_types = [dtypes.as_dtype(t) for t in dense_types]
self.dense_shapes = [tensor_shape.as_shape(s) for s in dense_shapes]
self.dense_defaults = dense_defaults
self.ragged_keys = ragged_keys
self.ragged_value_types = [dtypes.as_dtype(t) for t in ragged_value_types]
self.ragged_split_types = [dtypes.as_dtype(t) for t in ragged_split_types]
self._validate()
@classmethod
def from_features(cls, features, types):
"""Builds _ParseOpParams for a given set of features and allowed types.
Args:
features: A `dict` mapping feature keys to objects of a type in `types`.
types: Type of features to allow, among `FixedLenFeature`,
`VarLenFeature`, `SparseFeature`, and `FixedLenSequenceFeature`.
Returns:
A `_ParseOpParams` containing the raw parameters for `gen_parsing_ops`.
Raises:
ValueError: if `features` contains an item not in `types`, or an invalid
feature.
ValueError: if sparse and dense key sets intersect.
ValueError: if input lengths do not match up.
"""
params = cls()
if features:
# NOTE: We iterate over sorted keys to keep things deterministic.
for key in sorted(features.keys()):
feature = features[key]
if not isinstance(feature, tuple(types)):
raise ValueError("Unsupported %s %s." %
(type(feature).__name__, feature))
params._add_feature(key, feature) # pylint: disable=protected-access
return params
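# For illustration: a VarLenFeature contributes a sparse key/type pair to
# the resulting params, e.g.
#   params = _ParseOpParams.from_features(
#       {"tags": VarLenFeature(dtypes.string)}, [VarLenFeature])
#   params.sparse_keys   # -> ["tags"]
#   params.sparse_types  # -> [tf.string]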
@property
def dense_shapes_as_proto(self):
return [shape.as_proto() for shape in self.dense_shapes]
@property
def num_features(self):
return len(self.dense_keys) + len(self.sparse_keys) + len(self.ragged_keys)
@property
def dense_defaults_vec(self):
return [
self._make_dense_default(k, s, t)
for k, s, t in zip(self.dense_keys, self.dense_shapes, self.dense_types)
]
def _make_dense_default(self, key, shape, dtype):
"""Construct the default value tensor for a specified dense feature.
Args:
key: The key string identifying the dense feature.
shape: The dense feature's shape.
dtype: The dense feature's dtype.
Returns:
A Tensor.
"""
default_value = self.dense_defaults.get(key)
if (shape.ndims is not None and shape.ndims > 0 and
shape.dims[0].value is None):
# Variable stride dense shape, the default value should be a
# scalar padding value.
if default_value is None:
default_value = ops.convert_to_tensor(
"" if dtype == dtypes.string else 0, dtype=dtype)
else:
# Reshape to a scalar to ensure user gets an error if they
# provide a tensor that's not intended to be a padding value
# (0 or 2+ elements).
key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
default_value = ops.convert_to_tensor(
default_value, dtype=dtype, name=key_name)
default_value = array_ops.reshape(default_value, [])
else:
if default_value is None:
default_value = constant_op.constant([], dtype=dtype)
elif not isinstance(default_value, ops.Tensor):
key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
default_value = ops.convert_to_tensor(
default_value, dtype=dtype, name=key_name)
default_value = array_ops.reshape(default_value, shape)
return default_value
def _add_feature(self, key, feature):
"""Adds the specified feature to this ParseOpParams."""
if isinstance(feature, VarLenFeature):
self._add_varlen_feature(key, feature)
elif isinstance(feature, SparseFeature):
self._add_sparse_feature(key, feature)
elif isinstance(feature, FixedLenFeature):
self._add_fixed_len_feature(key, feature)
elif isinstance(feature, FixedLenSequenceFeature):
self._add_fixed_len_sequence_feature(key, feature)
elif isinstance(feature, RaggedFeature):
self._add_ragged_feature(key, feature)
else:
raise ValueError("Invalid feature %s:%s." % (key, feature))
def _add_varlen_feature(self, key, feature):
"""Adds a VarLenFeature."""
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
self._add_sparse_key(key, feature.dtype)
def _add_sparse_key(self, key, dtype):
"""Adds a sparse key & dtype, checking for duplicates."""
if key in self.sparse_keys:
original_dtype = self.sparse_types[self.sparse_keys.index(key)]
if original_dtype != dtype:
raise ValueError("Conflicting type %s vs %s for feature %s." %
(original_dtype, dtype, key))
else:
self.sparse_keys.append(key)
self.sparse_types.append(dtype)
def _add_sparse_feature(self, key, feature):
"""Adds a SparseFeature."""
if not feature.index_key:
raise ValueError("Missing index_key for SparseFeature %s." % (feature,))
if not feature.value_key:
raise ValueError("Missing value_key for SparseFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
index_keys = feature.index_key
if isinstance(index_keys, str):
index_keys = [index_keys]
elif len(index_keys) > 1:
tf_logging.warning("SparseFeature is a complicated feature config "
"and should only be used after careful "
"consideration of VarLenFeature.")
for index_key in sorted(index_keys):
self._add_sparse_key(index_key, dtypes.int64)
self._add_sparse_key(feature.value_key, feature.dtype)
def _add_fixed_len_feature(self, key, feature):
"""Adds a FixedLenFeature."""
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
raise ValueError("Missing shape for feature %s." % key)
feature_tensor_shape = tensor_shape.as_shape(feature.shape)
if (feature.shape and feature_tensor_shape.ndims and
feature_tensor_shape.dims[0].value is None):
raise ValueError("First dimension of shape for feature %s unknown. "
"Consider using FixedLenSequenceFeature." % key)
if (feature.shape is not None and
not feature_tensor_shape.is_fully_defined()):
raise ValueError("All dimensions of shape for feature %s need to be "
"known but received %s." % (key, str(feature.shape)))
self.dense_keys.append(key)
self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
self.dense_types.append(feature.dtype)
if feature.default_value is not None:
self.dense_defaults[key] = feature.default_value
def _add_fixed_len_sequence_feature(self, key, feature):
"""Adds a FixedLenSequenceFeature."""
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
raise ValueError("Missing shape for feature %s." % key)
self.dense_keys.append(key)
self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
self.dense_types.append(feature.dtype)
if feature.allow_missing:
self.dense_defaults[key] = None
if feature.default_value is not None:
self.dense_defaults[key] = feature.default_value
def _add_ragged_key(self, key, value_type, split_type):
"""Adds a ragged key & dtype, checking for duplicates."""
if key in self.ragged_keys:
original_value_type = self.ragged_value_types[self.ragged_keys.index(key)]
original_split_type = self.ragged_split_types[self.ragged_keys.index(key)]
if original_value_type != value_type:
raise ValueError("Conflicting type %s vs %s for feature %s." %
(original_value_type, value_type, key))
if original_split_type != split_type:
raise ValueError("Conflicting partition type %s vs %s for feature %s." %
(original_split_type, split_type, key))
else:
self.ragged_keys.append(key)
self.ragged_value_types.append(value_type)
self.ragged_split_types.append(split_type)
def _add_ragged_feature(self, key, feature):
"""Adds a RaggedFeature."""
value_key = key if feature.value_key is None else feature.value_key
self._add_ragged_key(value_key, feature.dtype, feature.row_splits_dtype)
for partition in feature.partitions:
if not isinstance(partition, RaggedFeature.UniformRowLength):
self._add_ragged_key(partition.key, dtypes.int64,
feature.row_splits_dtype)
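As a sketch, a hypothetical `RaggedFeature` (key names invented) and the ragged keys the method above registers for it:

```python
import tensorflow as tf

feature = tf.io.RaggedFeature(
    dtype=tf.int64,
    value_key="vals",
    partitions=[tf.io.RaggedFeature.RowLengths("lens")])
# Registers "vals" with value type int64 and, since RowLengths is not a
# UniformRowLength, also "lens" with value type int64; both use the
# feature's row_splits_dtype as their split type.
```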
def _validate(self):
"""Validates the features in this ParseOpParams."""
if len(self.dense_shapes) != len(self.dense_keys):
raise ValueError(
"len(self.dense_shapes) != len(self.dense_keys): %d vs %d" %
(len(self.dense_shapes), len(self.dense_keys)))
if len(self.dense_types) != len(self.dense_keys):
raise ValueError(
"len(self.dense_types) != len(self.dense_keys): %d vs %d" %
(len(self.dense_types), len(self.dense_keys)))
if len(self.sparse_types) != len(self.sparse_keys):
raise ValueError(
"len(self.sparse_types) != len(self.sparse_keys): %d vs %d" %
(len(self.sparse_types), len(self.sparse_keys)))
if len(self.ragged_value_types) != len(self.ragged_keys):
raise ValueError(
"len(self.ragged_value_types) != len(self.ragged_keys): %d vs %d" %
(len(self.ragged_value_types), len(self.ragged_keys)))
if len(self.ragged_split_types) != len(self.ragged_keys):
raise ValueError(
"len(self.ragged_split_types) != len(self.ragged_keys): %d vs %d" %
(len(self.ragged_split_types), len(self.ragged_keys)))
dense_key_set = set(self.dense_keys)
sparse_key_set = set(self.sparse_keys)
ragged_key_set = set(self.ragged_keys)
if not dense_key_set.isdisjoint(sparse_key_set):
raise ValueError(
"Dense and sparse keys must not intersect; intersection: %s" %
dense_key_set.intersection(sparse_key_set))
if not dense_key_set.isdisjoint(ragged_key_set):
raise ValueError(
"Dense and ragged keys must not intersect; intersection: %s" %
dense_key_set.intersection(ragged_key_set))
if not ragged_key_set.isdisjoint(sparse_key_set):
raise ValueError(
"Ragged and sparse keys must not intersect; intersection: %s" %
ragged_key_set.intersection(sparse_key_set))
def _construct_tensors_for_composite_features(features, tensor_dict):
"""Creates tensors for SparseFeatures and RaggedFeatures.
Constructs new dict based on `tensor_dict`.
For each key in `features` whose value is a `SparseFeature`:
* Looks up that SparseFeature's value_key and index_keys in tensor_dict.
* Uses those tensors to construct a single SparseTensor.
* Stores that SparseTensor in the output dict under the same key.
For each key in `features` whose value is a `RaggedFeature`:
* Looks up that RaggedFeature's value_key and partition keys in tensor_dict.
* Uses those tensors to construct a single RaggedTensor.
* Stores that RaggedTensor in the output dict under the same key.
For any other key in `features`:
* Copies that key and its value from tensor_dict to the output dictionary.
Args:
features: A `dict` mapping feature keys to `SparseFeature` or
`RaggedFeature` values. Values of other types will be ignored.
tensor_dict: A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
`RaggedTensor` values. Expected to contain keys of the `SparseFeature`s'
`index_key`s and `value_key`s and mapping them to `SparseTensor`s.
Returns:
A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
`RaggedTensor` values. Similar to `tensor_dict` except each `SparseFeature`
in `features` results in a single `SparseTensor`; and each `RaggedFeature`
in `features` results in a single `RaggedTensor`.
"""
tensor_dict = dict(tensor_dict) # Do not modify argument passed in.
updates = {}
for key in sorted(features.keys()):
feature = features[key]
if isinstance(feature, SparseFeature):
# Construct SparseTensors for SparseFeatures
if isinstance(feature.index_key, str):
sp_ids = tensor_dict[feature.index_key]
else:
sp_ids = [tensor_dict[index_key] for index_key in feature.index_key]
sp_values = tensor_dict[feature.value_key]
updates[key] = sparse_ops.sparse_merge(
sp_ids,
sp_values,
vocab_size=feature.size,
already_sorted=feature.already_sorted)
elif isinstance(feature, RaggedFeature):
# Construct RaggedTensors for RaggedFeatures.
value_key = key if feature.value_key is None else feature.value_key
rt = tensor_dict[value_key]
if isinstance(rt, ragged_tensor.RaggedTensor):
# We processed a vector of serialized tf.Examples.
for partition in reversed(feature.partitions):
rt = _add_batched_ragged_partition(rt, partition, tensor_dict,
feature.validate)
else:
# We processed a single serialized tf.Example.
for partition in reversed(feature.partitions):
rt = _add_ragged_partition(rt, partition, tensor_dict,
feature.row_splits_dtype, feature.validate)
updates[key] = rt
# Process updates after all composite tensors have been constructed (in case
# multiple features use the same value_key, and one uses that key as its
# feature key).
tensor_dict.update(updates)
# Remove tensors from dictionary that were only used to construct
# tensors for SparseFeature or RaggedTensor.
for key in set(tensor_dict) - set(features):
del tensor_dict[key]
return tensor_dict
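An end-to-end sketch of this rewrite for a ragged feature, assuming TensorFlow with this change applied (the keys `vals`, `lens`, and `r` are invented):

```python
import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    "vals": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3])),
    "lens": tf.train.Feature(int64_list=tf.train.Int64List(value=[2, 1])),
}))
parsed = tf.io.parse_example(
    tf.constant([example.SerializeToString()]),
    {"r": tf.io.RaggedFeature(
        tf.int64, value_key="vals",
        partitions=[tf.io.RaggedFeature.RowLengths("lens")])})
# parsed == {"r": <tf.RaggedTensor [[[1, 2], [3]]]>}; the intermediate
# "vals" and "lens" entries are dropped because they are not in `features`.
```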
def _add_ragged_partition(values, partition, tensor_dict, row_splits_dtype,
validate):
"""Creates a RaggedTensor from a values tensor and a partition tensor.
Args:
values: The values tensor for the new RaggedTensor.
partition: The partition configuration object. Specifies the key that
should be used to look up the partition tensor (unless partition is a
RaggedFeature.UniformRowLength, in which case there is no partition
tensor).
tensor_dict: The dictionary mapping keys to tensors.
row_splits_dtype: The dtype for the partition tensor.
validate: Whether to validate that the values form a valid RaggedTensor.
Returns:
A new RaggedTensor formed from the values and partition tensors.
"""
if isinstance(partition, RaggedFeature.UniformRowLength):
if isinstance(values, ragged_tensor.RaggedTensor):
length = ops.convert_to_tensor(partition.length, dtype=row_splits_dtype)
return ragged_tensor.RaggedTensor.from_uniform_row_length(
values, length, validate=validate)
else:
return array_ops.reshape(values, array_ops.concat(
[[-1, partition.length], array_ops.shape(values)[1:]], axis=0))
else:
partition_t = math_ops.cast(tensor_dict[partition.key], row_splits_dtype)
if isinstance(partition, RaggedFeature.RowSplits):
return ragged_tensor.RaggedTensor.from_row_splits(
values, partition_t, validate=validate)
elif isinstance(partition, RaggedFeature.RowLengths):
return ragged_tensor.RaggedTensor.from_row_lengths(
values, partition_t, validate=validate)
elif isinstance(partition, RaggedFeature.RowStarts):
return ragged_tensor.RaggedTensor.from_row_starts(
values, partition_t, validate=validate)
elif isinstance(partition, RaggedFeature.RowLimits):
return ragged_tensor.RaggedTensor.from_row_limits(
values, partition_t, validate=validate)
elif isinstance(partition, RaggedFeature.ValueRowIds):
return ragged_tensor.RaggedTensor.from_value_rowids(
values, partition_t, validate=validate)
raise ValueError("Unhandled partition type %r" % partition)
def _add_batched_ragged_partition(rt, partition, tensor_dict, validate):
"""Adds a batched ragged partition tensor to a batched ragged tensor.
Args:
rt: A RaggedTensor with shape [batch_size, ...].
partition: The partition configuration object. Specifies the key that
should be used to look up the partition tensor (unless partition is a
RaggedFeature.UniformRowLength, in which case there is no partition
tensor). The specified tensor must have shape [batch_size, ...].
tensor_dict: The dictionary mapping keys to tensors.
validate: Whether to validate that the values form a valid RaggedTensor.
Returns:
A new RaggedTensor where each batch item `rt[i]` has been partitioned
using the `partition_t[i]`.
"""
if isinstance(partition, RaggedFeature.UniformRowLength):
if rt.ragged_rank > 1:
length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype)
return ragged_tensor.RaggedTensor.from_row_splits(
ragged_tensor.RaggedTensor.from_uniform_row_length(
rt.values, length, validate=validate),
rt.row_splits // length,
validate=validate)
else:
reshaped_vals = array_ops.reshape(rt.values, array_ops.concat(
[[-1, partition.length], array_ops.shape(rt.values)[1:]], axis=0))
return ragged_tensor.RaggedTensor.from_row_splits(
reshaped_vals, rt.row_splits // partition.length, validate=validate)
partition_t = tensor_dict[partition.key]
if partition_t.values.dtype != rt.row_splits.dtype:
partition_t = math_ops.cast(partition_t, rt.row_splits.dtype)
if isinstance(partition, (RaggedFeature.RowSplits, RaggedFeature.RowLimits)):
if isinstance(partition, RaggedFeature.RowSplits):
partition_t = partition_t[:, 1:]
adjusted_limits = partition_t.values + array_ops.repeat(
rt.row_starts(), partition_t.row_lengths())
return partition_t.with_values(
ragged_tensor.RaggedTensor.from_row_limits(
rt.values, adjusted_limits, validate=validate))
elif isinstance(partition, RaggedFeature.RowStarts):
adjusted_starts = partition_t.values + array_ops.repeat(
rt.row_starts(), partition_t.row_lengths())
return partition_t.with_values(
ragged_tensor.RaggedTensor.from_row_starts(
rt.values, adjusted_starts, validate=validate))
elif isinstance(partition, RaggedFeature.RowLengths):
return partition_t.with_values(
ragged_tensor.RaggedTensor.from_row_lengths(
rt.values, partition_t.values, validate=validate))
elif isinstance(partition, RaggedFeature.ValueRowIds):
nrows = math_ops.maximum( # number of rows in each batch item
ragged_math_ops.reduce_max(partition_t + 1, axis=1), 0)
adjusted_rowids = partition_t.values + array_ops.repeat(
math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths())
return ragged_tensor.RaggedTensor.from_row_lengths(
ragged_tensor.RaggedTensor.from_value_rowids(
rt.values, adjusted_rowids, validate=validate),
nrows,
validate=validate)
raise ValueError("Unhandled partition type %r" % partition)
def _build_ragged_tensors(serialized_shape, ragged_values, ragged_row_splits):
"""Builds RaggedTensors from the outputs of a parse op."""
if serialized_shape.ndims == 0:
return ragged_values
else:
return [
ragged_tensor.RaggedTensor.from_row_splits(val, split, validate=False)
for (val, split) in zip(ragged_values, ragged_row_splits)
]
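For the batched case this amounts to (toy values):

```python
import tensorflow as tf

ragged_values = [tf.constant([1, 2, 3])]
ragged_row_splits = [tf.constant([0, 2, 3], dtype=tf.int64)]
tensors = [tf.RaggedTensor.from_row_splits(v, s, validate=False)
           for v, s in zip(ragged_values, ragged_row_splits)]
# tensors == [<tf.RaggedTensor [[1, 2], [3]]>]
```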

View File

@ -18,24 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.compat import compat
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import parsing_config
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_parsing_ops import *
# pylint: enable=wildcard-import,undefined-variable
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@ -47,432 +43,21 @@ ops.NotDifferentiable("SerializeTensor")
ops.NotDifferentiable("StringToNumber")
@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"])
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
"""Configuration for parsing a variable-length input feature.
Fields:
dtype: Data type of input.
"""
pass
VarLenFeature = parsing_config.VarLenFeature
RaggedFeature = parsing_config.RaggedFeature
SparseFeature = parsing_config.SparseFeature
FixedLenFeature = parsing_config.FixedLenFeature
FixedLenSequenceFeature = parsing_config.FixedLenSequenceFeature
# pylint: disable=protected-access
_ParseOpParams = parsing_config._ParseOpParams
_construct_tensors_for_composite_features = (
parsing_config._construct_tensors_for_composite_features)
# pylint: enable=protected-access
@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"])
class SparseFeature(
collections.namedtuple(
"SparseFeature",
["index_key", "value_key", "dtype", "size", "already_sorted"])):
"""Configuration for parsing a sparse input feature from an `Example`.
Note, preferably use `VarLenFeature` (possibly in combination with a
`SequenceExample`) in order to parse out `SparseTensor`s instead of
`SparseFeature` due to its simplicity.
Closely mimicking the `SparseTensor` that will be obtained by parsing an
`Example` with a `SparseFeature` config, a `SparseFeature` contains a
* `value_key`: The name of key for a `Feature` in the `Example` whose parsed
`Tensor` will be the resulting `SparseTensor.values`.
* `index_key`: A list of names - one for each dimension in the resulting
`SparseTensor`. `indices[i][dim]`, the position of the `i`-th value in
dimension `dim`, will be equal to the `i`-th value in the feature with key
`index_key[dim]` in the `Example`.
* `size`: A list of ints for the resulting `SparseTensor.dense_shape`.
For example, we can represent the following 2D `SparseTensor`
```python
SparseTensor(indices=[[3, 1], [20, 0]],
values=[0.5, -1.0],
dense_shape=[100, 3])
```
with an `Example` input proto
```python
features {
feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } }
feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } }
}
```
and `SparseFeature` config with 2 `index_key`s
```python
SparseFeature(index_key=["ix0", "ix1"],
value_key="val",
dtype=tf.float32,
size=[100, 3])
```
Fields:
index_key: A single string name or a list of string names of index features.
For each key the underlying feature's type must be `int64` and its length
must always match that of the `value_key` feature.
To represent `SparseTensor`s with a `dense_shape` of `rank` higher than 1
a list of length `rank` should be used.
value_key: Name of value feature. The underlying feature's type must
be `dtype` and its length must always match that of all the `index_key`s'
features.
dtype: Data type of the `value_key` feature.
size: A Python int or list thereof specifying the dense shape. Should be a
list if and only if `index_key` is a list, in which case it must have the
same length as `index_key`. For each entry `i`, all values in the
`index_key[i]` feature must be in `[0, size[i])`.
already_sorted: A Python boolean to specify whether the values in
`value_key` are already sorted by their index position. If so skip
sorting. False by default (optional).
"""
def __new__(cls, index_key, value_key, dtype, size, already_sorted=False):
return super(SparseFeature, cls).__new__(
cls, index_key, value_key, dtype, size, already_sorted)
@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"])
class FixedLenFeature(collections.namedtuple(
"FixedLenFeature", ["shape", "dtype", "default_value"])):
"""Configuration for parsing a fixed-length input feature.
To treat sparse input as dense, provide a `default_value`; otherwise,
the parse functions will fail on any examples missing this feature.
Fields:
shape: Shape of input data.
dtype: Data type of input.
default_value: Value to be used if an example is missing this feature. It
must be compatible with `dtype` and of the specified `shape`.
"""
def __new__(cls, shape, dtype, default_value=None):
return super(FixedLenFeature, cls).__new__(
cls, shape, dtype, default_value)
@tf_export("io.FixedLenSequenceFeature",
v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"])
class FixedLenSequenceFeature(collections.namedtuple(
"FixedLenSequenceFeature",
["shape", "dtype", "allow_missing", "default_value"])):
"""Configuration for parsing a variable-length input feature into a `Tensor`.
The resulting `Tensor` of parsing a single `SequenceExample` or `Example` has
a static `shape` of `[None] + shape` and the specified `dtype`.
The resulting `Tensor` of parsing a `batch_size` many `Example`s has
a static `shape` of `[batch_size, None] + shape` and the specified `dtype`.
The entries in the `batch` from different `Examples` will be padded with
`default_value` to the maximum length present in the `batch`.
To treat a sparse input as dense, provide `allow_missing=True`; otherwise,
the parse functions will fail on any examples missing this feature.
Fields:
shape: Shape of input data for dimension 2 and higher. First dimension is
of variable length `None`.
dtype: Data type of input.
allow_missing: Whether to allow this feature to be missing from a feature
list item. Available only when parsing `SequenceExample`, not when parsing
`Example`s.
default_value: Scalar value to be used to pad multiple `Example`s to their
maximum length. Irrelevant for parsing a single `Example` or
`SequenceExample`. Defaults to "" for dtype string and 0 otherwise
(optional).
"""
def __new__(cls, shape, dtype, allow_missing=False, default_value=None):
return super(FixedLenSequenceFeature, cls).__new__(
cls, shape, dtype, allow_missing, default_value)
class _ParseOpParams(object):
"""Raw parameters used by `gen_parsing_ops`.
Attributes:
sparse_keys: A list of string keys in the examples' features. The results
for these keys will be returned as `SparseTensor` objects.
sparse_types: A list of `DTypes` of the same length as `sparse_keys`. Only
`tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
(`BytesList`) are supported.
dense_keys: A list of string keys in the examples' features. The results for
these keys will be returned as `Tensor`s
dense_types: A list of DTypes of the same length as `dense_keys`. Only
`tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
(`BytesList`) are supported.
dense_defaults: A dict mapping string keys to `Tensor`s. The keys of the
dict must match the dense_keys of the feature.
dense_shapes: A list of tuples with the same length as `dense_keys`. The
shape of the data for each dense feature referenced by `dense_keys`.
Required for any input tensors identified by `dense_keys`. Must be either
fully defined, or may contain an unknown first dimension. An unknown first
dimension means the feature is treated as having a variable number of
blocks, and the output shape along this dimension is considered unknown at
graph build time. Padding is applied for minibatch elements smaller than
the maximum number of blocks for the given feature along this dimension.
dense_shapes_as_proto: dense_shapes converted to TensorShapeProto.
dense_defaults_vec: A vector of `Tensor`s containing the default values,
corresponding 1:1 with `dense_keys`.
num_features: The total number of feature keys.
"""
def __init__(self,
sparse_keys=None,
sparse_types=None,
dense_keys=None,
dense_types=None,
dense_defaults=None,
dense_shapes=None):
# Note: we use an OrderedDict for dense_defaults, to ensure consistent
# graph construction order for _e2e_test.
dense_defaults = (
collections.OrderedDict() if dense_defaults is None else dense_defaults)
sparse_keys = [] if sparse_keys is None else sparse_keys
sparse_types = [] if sparse_types is None else sparse_types
dense_keys = [] if dense_keys is None else dense_keys
dense_types = [] if dense_types is None else dense_types
dense_shapes = ([[]] *
len(dense_keys) if dense_shapes is None else dense_shapes)
self.sparse_keys = sparse_keys
self.sparse_types = [dtypes.as_dtype(t) for t in sparse_types]
self.dense_keys = dense_keys
self.dense_types = [dtypes.as_dtype(t) for t in dense_types]
self.dense_shapes = [tensor_shape.as_shape(s) for s in dense_shapes]
self.dense_defaults = dense_defaults
self._validate()
@classmethod
def from_features(cls, features, types):
"""Builds _ParseOpParams for a given set of features and allowed types.
Args:
features: A `dict` mapping feature keys to objects of a type in `types`.
types: Type of features to allow, among `FixedLenFeature`,
`VarLenFeature`, `SparseFeature`, and `FixedLenSequenceFeature`.
Returns:
A `_ParseOpParams` containing the raw parameters for `gen_parsing_ops`.
Raises:
ValueError: if `features` contains an item not in `types`, or an invalid
feature.
ValueError: if sparse and dense key sets intersect.
ValueError: if input lengths do not match up.
"""
params = cls()
if features:
# NOTE: We iterate over sorted keys to keep things deterministic.
for key in sorted(features.keys()):
feature = features[key]
if not isinstance(feature, tuple(types)):
raise ValueError("Unsupported %s %s." %
(type(feature).__name__, feature))
params._add_feature(key, feature) # pylint: disable=protected-access
return params
@property
def dense_shapes_as_proto(self):
return [shape.as_proto() for shape in self.dense_shapes]
@property
def num_features(self):
return len(self.dense_keys) + len(self.sparse_keys)
@property
def dense_defaults_vec(self):
return [
self._make_dense_default(k, s, t)
for k, s, t in zip(self.dense_keys, self.dense_shapes, self.dense_types)
]
def _make_dense_default(self, key, shape, dtype):
"""Construct the default value tensor for a specified dense feature.
Args:
key: The key string identifying the dense feature.
shape: The dense feature's shape.
dtype: The dense feature's dtype.
Returns:
A Tensor.
"""
default_value = self.dense_defaults.get(key)
if (shape.ndims is not None and shape.ndims > 0 and
shape.dims[0].value is None):
# Variable stride dense shape, the default value should be a
# scalar padding value.
if default_value is None:
default_value = ops.convert_to_tensor(
"" if dtype == dtypes.string else 0, dtype=dtype)
else:
# Reshape to a scalar to ensure user gets an error if they
# provide a tensor that's not intended to be a padding value
# (0 or 2+ elements).
key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
default_value = ops.convert_to_tensor(
default_value, dtype=dtype, name=key_name)
default_value = array_ops.reshape(default_value, [])
else:
if default_value is None:
default_value = constant_op.constant([], dtype=dtype)
elif not isinstance(default_value, ops.Tensor):
key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
default_value = ops.convert_to_tensor(
default_value, dtype=dtype, name=key_name)
default_value = array_ops.reshape(default_value, shape)
return default_value
def _add_feature(self, key, feature):
"""Adds the specified feature to this ParseOpParams."""
if isinstance(feature, VarLenFeature):
self._add_varlen_feature(key, feature)
elif isinstance(feature, SparseFeature):
self._add_sparse_feature(key, feature)
elif isinstance(feature, FixedLenFeature):
self._add_fixed_len_feature(key, feature)
elif isinstance(feature, FixedLenSequenceFeature):
self._add_fixed_len_sequence_feature(key, feature)
else:
raise ValueError("Invalid feature %s:%s." % (key, feature))
def _add_varlen_feature(self, key, feature):
"""Adds a VarLenFeature."""
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
self._add_sparse_key(key, feature.dtype)
def _add_sparse_key(self, key, dtype):
"""Adds a sparse key & dtype, checking for duplicates."""
if key in self.sparse_keys:
original_dtype = self.sparse_types[self.sparse_keys.index(key)]
if original_dtype != dtype:
raise ValueError("Conflicting type %s vs %s for feature %s." %
(original_dtype, dtype, key))
else:
self.sparse_keys.append(key)
self.sparse_types.append(dtype)
def _add_sparse_feature(self, key, feature):
"""Adds a SparseFeature."""
if not feature.index_key:
raise ValueError("Missing index_key for SparseFeature %s." % (feature,))
if not feature.value_key:
raise ValueError("Missing value_key for SparseFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
index_keys = feature.index_key
if isinstance(index_keys, str):
index_keys = [index_keys]
elif len(index_keys) > 1:
tf_logging.warning("SparseFeature is a complicated feature config "
"and should only be used after careful "
"consideration of VarLenFeature.")
for index_key in sorted(index_keys):
self._add_sparse_key(index_key, dtypes.int64)
self._add_sparse_key(feature.value_key, feature.dtype)
def _add_fixed_len_feature(self, key, feature):
"""Adds a FixedLenFeature."""
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
raise ValueError("Missing shape for feature %s." % key)
feature_tensor_shape = tensor_shape.as_shape(feature.shape)
if (feature.shape and feature_tensor_shape.ndims and
feature_tensor_shape.dims[0].value is None):
raise ValueError("First dimension of shape for feature %s unknown. "
"Consider using FixedLenSequenceFeature." % key)
if (feature.shape is not None and
not feature_tensor_shape.is_fully_defined()):
raise ValueError("All dimensions of shape for feature %s need to be "
"known but received %s." % (key, str(feature.shape)))
self.dense_keys.append(key)
self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
self.dense_types.append(feature.dtype)
if feature.default_value is not None:
self.dense_defaults[key] = feature.default_value
def _add_fixed_len_sequence_feature(self, key, feature):
"""Adds a FixedLenSequenceFeature."""
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
raise ValueError("Missing shape for feature %s." % key)
self.dense_keys.append(key)
self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
self.dense_types.append(feature.dtype)
if feature.allow_missing:
self.dense_defaults[key] = None
if feature.default_value is not None:
self.dense_defaults[key] = feature.default_value
def _validate(self):
"""Validates the features in this ParseOpParams."""
if len(self.dense_shapes) != len(self.dense_keys):
raise ValueError(
"len(self.dense_shapes) != len(self.dense_keys): %d vs %d" %
(len(self.dense_shapes), len(self.dense_keys)))
if len(self.dense_types) != len(self.dense_keys):
raise ValueError(
"len(self.dense_types) != len(self.dense_keys): %d vs %d" %
(len(self.dense_types), len(self.dense_keys)))
if len(self.sparse_types) != len(self.sparse_keys):
raise ValueError(
"len(self.sparse_types) != len(self.sparse_keys): %d vs %d" %
(len(self.sparse_types), len(self.sparse_keys)))
dense_key_set = set(self.dense_keys)
sparse_key_set = set(self.sparse_keys)
if not dense_key_set.isdisjoint(sparse_key_set):
raise ValueError(
"Dense and sparse keys must not intersect; intersection: %s" %
dense_key_set.intersection(sparse_key_set))
def _construct_sparse_tensors_for_sparse_features(features, tensor_dict):
"""Merges SparseTensors of indices and values of SparseFeatures.
Constructs new dict based on `tensor_dict`. For `SparseFeatures` in the values
of `features` expects their `index_key`s and `index_value`s to be present in
`tensor_dict` mapping to `SparseTensor`s. Constructs a single `SparseTensor`
from them, and adds it to the result with the key from `features`.
Copies other keys and values from `tensor_dict` with keys present in
`features`.
Args:
features: A `dict` mapping feature keys to `SparseFeature` values.
Values of other types will be ignored.
tensor_dict: A `dict` mapping feature keys to `Tensor` and `SparseTensor`
values. Expected to contain keys of the `SparseFeature`s' `index_key`s and
`value_key`s and mapping them to `SparseTensor`s.
Returns:
A `dict` mapping feature keys to `Tensor` and `SparseTensor` values. Similar
to `tensor_dict` except each `SparseFeature`s in `features` results in a
single `SparseTensor`.
"""
tensor_dict = dict(tensor_dict) # Do not modify argument passed in.
# Construct SparseTensors for SparseFeatures.
for key in sorted(features.keys()):
feature = features[key]
if isinstance(feature, SparseFeature):
if isinstance(feature.index_key, str):
sp_ids = tensor_dict[feature.index_key]
else:
sp_ids = [tensor_dict[index_key] for index_key in feature.index_key]
sp_values = tensor_dict[feature.value_key]
tensor_dict[key] = sparse_ops.sparse_merge(
sp_ids,
sp_values,
vocab_size=feature.size,
already_sorted=feature.already_sorted)
# Remove tensors from dictionary that were only used to construct
# SparseTensors for SparseFeature.
for key in set(tensor_dict) - set(features):
del tensor_dict[key]
return tensor_dict
# TODO(b/122887740) Switch files that use this private symbol to use new name.
_construct_sparse_tensors_for_sparse_features = \
_construct_tensors_for_composite_features
def _prepend_none_dimension(features):
@ -494,223 +79,6 @@ def _prepend_none_dimension(features):
return features
@tf_export(v1=["io.parse_example", "parse_example"])
def parse_example(serialized, features, name=None, example_names=None):
# pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors.
Parses a number of serialized [`Example`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
protos given in `serialized`. We refer to `serialized` as a batch with
`batch_size` many entries of individual `Example` protos.
`example_names` may contain descriptive names for the corresponding serialized
protos. These may be useful for debugging purposes, but they have no effect on
the output. If not `None`, `example_names` must be the same length as
`serialized`.
This op parses serialized examples into a dictionary mapping keys to `Tensor`
and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
`SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
and `SparseFeature` is mapped to a `SparseTensor`, and each
`FixedLenFeature` is mapped to a `Tensor`.
Each `VarLenFeature` maps to a `SparseTensor` of the specified type
representing a ragged matrix. Its indices are `[batch, index]` where `batch`
identifies the example in `serialized`, and `index` is the value's index in
the list of values associated with that feature and example.
Each `SparseFeature` maps to a `SparseTensor` of the specified type
representing a Tensor of `dense_shape` `[batch_size] + SparseFeature.size`.
Its `values` come from the feature in the examples with key `value_key`.
A `values[i]` comes from a position `k` in the feature of an example at batch
entry `batch`. This positional information is recorded in `indices[i]` as
`[batch, index_0, index_1, ...]` where `index_j` is the `k`-th value of
the feature with key `SparseFeature.index_key[j]` in that example.
In other words, we split the indices (except the first index indicating the
batch entry) of a `SparseTensor` by dimension into different features of the
`Example`. Due to its complexity a `VarLenFeature` should be preferred over a
`SparseFeature` whenever possible.
Each `FixedLenFeature` `df` maps to a `Tensor` of the specified type (or
`tf.float32` if not specified) and shape `(serialized.size(),) + df.shape`.
`FixedLenFeature` entries with a `default_value` are optional. With no default
value, we will fail if that `Feature` is missing from any example in
`serialized`.
Each `FixedLenSequenceFeature` `df` maps to a `Tensor` of the specified type
(or `tf.float32` if not specified) and shape
`(serialized.size(), None) + df.shape`.
All examples in `serialized` will be padded with `default_value` along the
second dimension.
Examples:
For example, if one expects a `tf.float32` `VarLenFeature` `ft` and three
serialized `Example`s are provided:
```
serialized = [
features
{ feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } },
features
{ },
features
{ feature { key: "ft" value { float_list { value: [3.0] } } }
]
```
then the output will look like:
```python
{"ft": SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
values=[1.0, 2.0, 3.0],
dense_shape=(3, 2)) }
```
If instead a `FixedLenSequenceFeature` with `default_value = -1.0` and
`shape=[]` is used then the output will look like:
```python
{"ft": [[1.0, 2.0], [3.0, -1.0]]}
```
Given two `Example` input protos in `serialized`:
```
[
features {
feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } }
feature { key: "gps" value { float_list { value: [] } } }
},
features {
feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } }
feature { key: "dank" value { int64_list { value: [ 42 ] } } }
feature { key: "gps" value { } }
}
]
```
And arguments
```
example_names: ["input0", "input1"],
features: {
"kw": VarLenFeature(tf.string),
"dank": VarLenFeature(tf.int64),
"gps": VarLenFeature(tf.float32),
}
```
Then the output is a dictionary:
```python
{
"kw": SparseTensor(
indices=[[0, 0], [0, 1], [1, 0]],
values=["knit", "big", "emmy"]
dense_shape=[2, 2]),
"dank": SparseTensor(
indices=[[1, 0]],
values=[42],
dense_shape=[2, 1]),
"gps": SparseTensor(
indices=[],
values=[],
dense_shape=[2, 0]),
}
```
For dense results in two serialized `Example`s:
```
[
features {
feature { key: "age" value { int64_list { value: [ 0 ] } } }
feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
},
features {
feature { key: "age" value { int64_list { value: [] } } }
feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
}
]
```
We can use arguments:
```
example_names: ["input0", "input1"],
features: {
"age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
"gender": FixedLenFeature([], dtype=tf.string),
}
```
And the expected output is:
```python
{
"age": [[0], [-1]],
"gender": [["f"], ["f"]],
}
```
An alternative to `VarLenFeature` to obtain a `SparseTensor` is
`SparseFeature`. For example, given two `Example` input protos in
`serialized`:
```
[
features {
feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
feature { key: "ix" value { int64_list { value: [ 3, 20 ] } } }
},
features {
feature { key: "val" value { float_list { value: [ 0.0 ] } } }
feature { key: "ix" value { int64_list { value: [ 42 ] } } }
}
]
```
And arguments
```
example_names: ["input0", "input1"],
features: {
"sparse": SparseFeature(
index_key="ix", value_key="val", dtype=tf.float32, size=100),
}
```
Then the output is a dictionary:
```python
{
"sparse": SparseTensor(
indices=[[0, 3], [0, 20], [1, 42]],
values=[0.5, -1.0, 0.0],
dense_shape=[2, 100]),
}
```
Args:
serialized: A vector (1-D Tensor) of strings, a batch of binary
serialized `Example` protos.
features: A `dict` mapping feature keys to `FixedLenFeature`,
`VarLenFeature`, and `SparseFeature` values.
name: A name for this operation (optional).
example_names: A vector (1-D Tensor) of strings (optional), the names of
the serialized protos in the batch.
Returns:
A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.
Raises:
ValueError: if any feature is invalid.
"""
return parse_example_v2(serialized, features, example_names, name)
@tf_export("io.parse_example", v1=[])
def parse_example_v2(serialized, features, example_names=None, name=None):
# pylint: disable=line-too-long
@ -726,10 +94,11 @@ def parse_example_v2(serialized, features, example_names=None, name=None):
`serialized`.
This op parses serialized examples into a dictionary mapping keys to `Tensor`
and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
`SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
and `SparseFeature` is mapped to a `SparseTensor`, and each
`FixedLenFeature` is mapped to a `Tensor`.
`SparseTensor`, and `RaggedTensor` objects. `features` is a dict from keys to
`VarLenFeature`, `SparseFeature`, `RaggedFeature`, and `FixedLenFeature`
objects. Each `VarLenFeature` and `SparseFeature` is mapped to a
`SparseTensor`; each `FixedLenFeature` is mapped to a `Tensor`; and each
`RaggedFeature` is mapped to a `RaggedTensor`.
Each `VarLenFeature` maps to a `SparseTensor` of the specified type
representing a ragged matrix. Its indices are `[batch, index]` where `batch`
@ -761,6 +130,12 @@ def parse_example_v2(serialized, features, example_names=None, name=None):
All examples in `serialized` will be padded with `default_value` along the
second dimension.
Each `RaggedFeature` maps to a `RaggedTensor` of the specified type. It
is formed by stacking the `RaggedTensor` for each example, where the
`RaggedTensor` for each individual example is constructed using the tensors
specified by `RaggedFeature.value_key` and `RaggedFeature.partitions`. See
the `tf.io.RaggedFeature` documentation for details and examples.
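For instance, a minimal sketch (the feature key is invented):

```python
parsed = tf.io.parse_example(
    serialized, {"words": tf.io.RaggedFeature(tf.string)})
# parsed["words"] is a RaggedTensor with one ragged row per input proto.
```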
Examples:
For example, if one expects a `tf.float32` `VarLenFeature` `ft` and three
@ -910,17 +285,21 @@ def parse_example_v2(serialized, features, example_names=None, name=None):
}
```
See the `tf.io.RaggedFeature` documentation for examples showing how
`RaggedFeature` can be used to obtain `RaggedTensor`s.
Args:
serialized: A vector (1-D Tensor) of strings, a batch of binary
serialized `Example` protos.
features: A `dict` mapping feature keys to `FixedLenFeature`,
`VarLenFeature`, and `SparseFeature` values.
`VarLenFeature`, `SparseFeature`, and `RaggedFeature` values.
example_names: A vector (1-D Tensor) of strings (optional), the names of
the serialized protos in the batch.
name: A name for this operation (optional).
Returns:
A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.
A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
`RaggedTensor` values.
Raises:
ValueError: if any feature is invalid.
@ -928,12 +307,21 @@ def parse_example_v2(serialized, features, example_names=None, name=None):
if not features:
raise ValueError("Missing: features was %s." % features)
features = _prepend_none_dimension(features)
params = _ParseOpParams.from_features(
features,
[VarLenFeature, SparseFeature, FixedLenFeature, FixedLenSequenceFeature])
params = _ParseOpParams.from_features(features, [
VarLenFeature, SparseFeature, FixedLenFeature, FixedLenSequenceFeature,
RaggedFeature
])
outputs = _parse_example_raw(serialized, example_names, params, name=name)
return _construct_sparse_tensors_for_sparse_features(features, outputs)
return _construct_tensors_for_composite_features(features, outputs)
@tf_export(v1=["io.parse_example", "parse_example"])
def parse_example(serialized, features, name=None, example_names=None):
return parse_example_v2(serialized, features, example_names, name)
parse_example.__doc__ = parse_example_v2.__doc__
def _parse_example_raw(serialized, names, params, name):
@ -948,32 +336,57 @@ def _parse_example_raw(serialized, names, params, name):
name: A name for this operation (optional).
Returns:
A `dict` mapping keys to `Tensor`s and `SparseTensor`s.
A `dict` mapping keys to `Tensor`s and `SparseTensor`s and `RaggedTensor`s.
"""
if params.num_features == 0:
raise ValueError("Must provide at least one feature key")
with ops.name_scope(name, "ParseExample", [serialized, names]):
names = [] if names is None else names
outputs = gen_parsing_ops.parse_example(
serialized=serialized,
names=names,
dense_defaults=params.dense_defaults_vec,
sparse_keys=params.sparse_keys,
sparse_types=params.sparse_types,
dense_keys=params.dense_keys,
dense_shapes=params.dense_shapes_as_proto,
name=name)
if compat.forward_compatible(2019, 10, 16) or params.ragged_keys:
serialized = ops.convert_to_tensor(serialized, name="serialized")
if params.ragged_keys and serialized.shape.ndims is None:
raise ValueError("serialized must have statically-known rank to "
"parse ragged features.")
outputs = gen_parsing_ops.parse_example_v2(
serialized=serialized,
names=names,
sparse_keys=params.sparse_keys,
dense_keys=params.dense_keys,
ragged_keys=params.ragged_keys,
dense_defaults=params.dense_defaults_vec,
num_sparse=len(params.sparse_keys),
sparse_types=params.sparse_types,
ragged_value_types=params.ragged_value_types,
ragged_split_types=params.ragged_split_types,
dense_shapes=params.dense_shapes_as_proto,
name=name)
(sparse_indices, sparse_values, sparse_shapes, dense_values,
ragged_values, ragged_row_splits) = outputs
# pylint: disable=protected-access
ragged_tensors = parsing_config._build_ragged_tensors(
serialized.shape, ragged_values, ragged_row_splits)
else:
outputs = gen_parsing_ops.parse_example(
serialized=serialized,
names=names,
dense_defaults=params.dense_defaults_vec,
sparse_keys=params.sparse_keys,
sparse_types=params.sparse_types,
dense_keys=params.dense_keys,
dense_shapes=params.dense_shapes_as_proto,
name=name)
(sparse_indices, sparse_values, sparse_shapes, dense_values) = outputs
ragged_tensors = []
sparse_tensors = [
sparse_tensor.SparseTensor(ix, val, shape) for (ix, val, shape)
in zip(sparse_indices, sparse_values, sparse_shapes)]
return dict(
zip(params.sparse_keys + params.dense_keys,
sparse_tensors + dense_values))
zip(params.sparse_keys + params.dense_keys + params.ragged_keys,
sparse_tensors + dense_values + ragged_tensors))
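The result dict pairs keys with outputs positionally: sparse first, then dense, then ragged. A plain-Python sketch with placeholder values:

```python
sparse_keys, dense_keys, ragged_keys = ["s"], ["d"], ["r"]
sparse_tensors, dense_values, ragged_tensors = ["SP"], ["DN"], ["RG"]
result = dict(zip(sparse_keys + dense_keys + ragged_keys,
                  sparse_tensors + dense_values + ragged_tensors))
assert result == {"s": "SP", "d": "DN", "r": "RG"}
```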
@tf_export(v1=["io.parse_single_example", "parse_single_example"])
@ -996,12 +409,10 @@ def parse_single_example(serialized, features, name=None, example_names=None):
Args:
serialized: A scalar string Tensor, a single serialized Example.
See `_parse_single_example_raw` documentation for more details.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
name: A name for this operation (optional).
example_names: (Optional) A scalar string Tensor, the associated name.
See `_parse_single_example_raw` documentation for more details.
Returns:
A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.
@ -1038,11 +449,9 @@ def parse_single_example_v2_unoptimized(
Args:
serialized: A scalar string Tensor, a single serialized Example.
See `_parse_single_example_raw` documentation for more details.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
example_names: (Optional) A scalar string Tensor, the associated name.
See `_parse_single_example_raw` documentation for more details.
name: A name for this operation (optional).
Returns:
@ -1053,6 +462,15 @@ def parse_single_example_v2_unoptimized(
"""
if not features:
raise ValueError("Missing features.")
any_ragged_features = any(
isinstance(f, RaggedFeature)
for f in features.values())
if compat.forward_compatible(2019, 10, 16) or any_ragged_features:
with ops.name_scope(name, "ParseSingleExample",
[serialized, example_names]):
serialized = ops.convert_to_tensor(serialized, name="serialized")
serialized = _assert_scalar(serialized, "serialized")
return parse_example_v2(serialized, features, example_names, name)
if example_names is None:
return parse_single_example_v2(serialized, features, name)
features = _prepend_none_dimension(features)
@ -1060,7 +478,7 @@ def parse_single_example_v2_unoptimized(
features,
[VarLenFeature, FixedLenFeature, FixedLenSequenceFeature, SparseFeature])
outputs = _parse_single_example_raw(serialized, example_names, params, name)
return _construct_sparse_tensors_for_sparse_features(features, outputs)
return _construct_tensors_for_composite_features(features, outputs)
def _parse_single_example_raw(serialized, names, params, name=None):
@ -1103,6 +521,8 @@ def _parse_single_example_raw(serialized, names, params, name=None):
array_ops.slice(
outputs[s].dense_shape, [1], [-1],
name="Squeeze_Shape_%s" % s_name))
for s in params.ragged_keys:
outputs[s] = outputs[s].values
return outputs
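Because the scalar input is parsed via the batched path, each ragged result carries a length-1 outer batch dimension; taking `.values` removes it. A toy sketch:

```python
import tensorflow as tf

batched = tf.ragged.constant([[1, 2, 3]])  # batch containing one example
single = batched.values                    # tf.Tensor([1, 2, 3])
```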
@ -1739,7 +1159,7 @@ def parse_single_example_v2(serialized, features, name=None):
features,
[VarLenFeature, FixedLenFeature, FixedLenSequenceFeature, SparseFeature])
outputs = _parse_single_example_v2_raw(serialized, params, name)
return _construct_sparse_tensors_for_sparse_features(features, outputs)
return _construct_tensors_for_composite_features(features, outputs)
def _parse_single_example_v2_raw(serialized, params, name):

View File

@ -252,7 +252,11 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output_graph_def.ParseFromString(f.read())
_ = importer.import_graph_def(output_graph_def, name="")
self.assertEqual(8, len(output_graph_def.node))
if any(u"ParseExampleV2" in node.name for node in output_graph_def.node):
expected_node_count = 10
else:
expected_node_count = 8
self.assertEqual(expected_node_count, len(output_graph_def.node))
for node in output_graph_def.node:
self.assertNotEqual("VariableV2", node.op)
self.assertNotEqual("Variable", node.op)

View File

@ -27,7 +27,7 @@ def extract_example_parser_configuration(parse_example_op, sess):
"""Returns an ExampleParserConfig proto.
Args:
parse_example_op: A ParseExample `Operation`
parse_example_op: A ParseExample or ParseExampleV2 `Operation`
sess: A tf.compat.v1.Session needed to obtain some configuration values.
Returns:
A ExampleParserConfig proto.
@ -35,6 +35,16 @@ def extract_example_parser_configuration(parse_example_op, sess):
Raises:
ValueError: If attributes are inconsistent.
"""
if parse_example_op.type == "ParseExample":
return _extract_from_parse_example(parse_example_op, sess)
elif parse_example_op.type == "ParseExampleV2":
return _extract_from_parse_example_v2(parse_example_op, sess)
else:
raise ValueError("Unexpeected op type: %s" % parse_example_op.type)
def _extract_from_parse_example(parse_example_op, sess):
"""Extract ExampleParserConfig from ParseExample op."""
config = example_parser_configuration_pb2.ExampleParserConfiguration()
num_sparse = parse_example_op.get_attr("Nsparse")
@ -120,3 +130,80 @@ def extract_example_parser_configuration(parse_example_op, sess):
sparse_shapes_start + i].name
return config
def _extract_from_parse_example_v2(parse_example_op, sess):
"""Extract ExampleParserConfig from ParseExampleV2 op."""
config = example_parser_configuration_pb2.ExampleParserConfiguration()
dense_types = parse_example_op.get_attr("Tdense")
num_sparse = parse_example_op.get_attr("num_sparse")
sparse_types = parse_example_op.get_attr("sparse_types")
ragged_value_types = parse_example_op.get_attr("ragged_value_types")
ragged_split_types = parse_example_op.get_attr("ragged_split_types")
dense_shapes = parse_example_op.get_attr("dense_shapes")
num_dense = len(dense_types)
num_ragged = len(ragged_value_types)
assert len(ragged_value_types) == len(ragged_split_types)
assert len(parse_example_op.inputs) == 5 + num_dense
# Skip over the serialized input, and the names input.
fetched = sess.run(parse_example_op.inputs[2:])
sparse_keys = fetched[0].tolist()
dense_keys = fetched[1].tolist()
ragged_keys = fetched[2].tolist()
dense_defaults = fetched[3:]
assert len(sparse_keys) == num_sparse
assert len(dense_keys) == num_dense
assert len(ragged_keys) == num_ragged
# Output tensor indices.
sparse_indices_start = 0
sparse_values_start = num_sparse
sparse_shapes_start = sparse_values_start + num_sparse
dense_values_start = sparse_shapes_start + num_sparse
ragged_values_start = dense_values_start + num_dense
ragged_row_splits_start = ragged_values_start + num_ragged
# Dense features.
for i in range(num_dense):
key = dense_keys[i]
feature_config = config.feature_map[key]
# Convert the default value numpy array fetched from the session run
# into a TensorProto.
fixed_config = feature_config.fixed_len_feature
fixed_config.default_value.CopyFrom(
tensor_util.make_tensor_proto(dense_defaults[i]))
# Convert the shape from the attributes
# into a TensorShapeProto.
fixed_config.shape.CopyFrom(
tensor_shape.TensorShape(dense_shapes[i]).as_proto())
fixed_config.dtype = dense_types[i].as_datatype_enum
# Get the output tensor name.
fixed_config.values_output_tensor_name = parse_example_op.outputs[
dense_values_start + i].name
# Sparse features.
for i in range(num_sparse):
key = sparse_keys[i]
feature_config = config.feature_map[key]
var_len_feature = feature_config.var_len_feature
var_len_feature.dtype = sparse_types[i].as_datatype_enum
var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
sparse_indices_start + i].name
var_len_feature.values_output_tensor_name = parse_example_op.outputs[
sparse_values_start + i].name
var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
sparse_shapes_start + i].name
if num_ragged != 0:
del ragged_values_start # unused
del ragged_row_splits_start # unused
raise ValueError("Ragged features are not yet supported by "
"example_parser_configuration.proto")
return config
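For reference, the output-offset arithmetic above for a hypothetical op with two sparse, one dense, and no ragged features:

```python
num_sparse, num_dense = 2, 1
sparse_indices_start = 0
sparse_values_start = num_sparse                         # 2
sparse_shapes_start = sparse_values_start + num_sparse   # 4
dense_values_start = sparse_shapes_start + num_sparse    # 6
ragged_values_start = dense_values_start + num_dense     # 7
assert ragged_values_start == 3 * num_sparse + num_dense
```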

View File

@ -28,7 +28,7 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util.example_parser_configuration import extract_example_parser_configuration
BASIC_PROTO = """
EXPECTED_CONFIG_V1 = """
feature_map {
key: "x"
value {
@ -66,23 +66,32 @@ feature_map {
"""
EXPECTED_CONFIG_V2 = EXPECTED_CONFIG_V1.replace(
'ParseExample/ParseExample:', 'ParseExample/ParseExampleV2:')
class ExampleParserConfigurationTest(test.TestCase):
def getExpectedConfig(self, op_type):
expected = example_parser_configuration_pb2.ExampleParserConfiguration()
if op_type == 'ParseExampleV2':
text_format.Parse(EXPECTED_CONFIG_V2, expected)
else:
text_format.Parse(EXPECTED_CONFIG_V1, expected)
return expected
def testBasic(self):
golden_config = example_parser_configuration_pb2.ExampleParserConfiguration(
)
text_format.Parse(BASIC_PROTO, golden_config)
with session.Session() as sess:
examples = array_ops.placeholder(dtypes.string, shape=[1])
feature_to_type = {
'x': parsing_ops.FixedLenFeature([1], dtypes.float32, 33.0),
'y': parsing_ops.VarLenFeature(dtypes.string)
}
_ = parsing_ops.parse_example(examples, feature_to_type)
parse_example_op = sess.graph.get_operation_by_name(
'ParseExample/ParseExample')
result = parsing_ops.parse_example(examples, feature_to_type)
parse_example_op = result['x'].op
config = extract_example_parser_configuration(parse_example_op, sess)
self.assertProtoEquals(golden_config, config)
expected = self.getExpectedConfig(parse_example_op.type)
self.assertProtoEquals(expected, config)
if __name__ == '__main__':

View File

@ -1,7 +1,7 @@
path: "tensorflow.FixedLenFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "default_value"

View File

@ -1,7 +1,7 @@
path: "tensorflow.FixedLenSequenceFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenSequenceFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "allow_missing"

View File

@ -1,7 +1,7 @@
path: "tensorflow.SparseFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.SparseFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "already_sorted"

View File

@ -1,7 +1,7 @@
path: "tensorflow.VarLenFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.VarLenFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "dtype"

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.FixedLenFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "default_value"

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.FixedLenSequenceFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenSequenceFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "allow_missing"

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowLengths"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowLengths\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowLimits"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowLimits\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowSplits"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowSplits\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowStarts"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowStarts\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.UniformRowLength"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.UniformRowLength\'>"
is_instance: "<type \'tuple\'>"
member {
name: "length"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}
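
UniformRowLength is the odd one out: its member is `length` rather than `key`, since it carries a constant row length instead of reading a partition feature from the Example. A sketch assuming the tf.RaggedTensor.from_uniform_row_length factory, with illustrative values:

import tensorflow as tf

values = tf.constant([1, 2, 3, 4, 5, 6])
# A fixed length of 2 per row; no per-example partition feature needed.
rt = tf.RaggedTensor.from_uniform_row_length(values, uniform_row_length=2)
print(rt)  # <tf.RaggedTensor [[1, 2], [3, 4], [5, 6]]>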

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.ValueRowIds"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.ValueRowIds\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}
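
ValueRowIds names a feature mapping each flat value to the index of the row it belongs to, as in tf.RaggedTensor.from_value_rowids. Sketch with illustrative values:

import tensorflow as tf

values = tf.constant([1, 2, 3, 4, 5])
# Row ids [0, 0, 2, 2, 2] with nrows=3 leave row 1 empty.
rt = tf.RaggedTensor.from_value_rowids(
    values, value_rowids=[0, 0, 2, 2, 2], nrows=3)
print(rt)  # <tf.RaggedTensor [[1, 2], [], [3, 4, 5]]>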

View File

@ -0,0 +1,59 @@
path: "tensorflow.io.RaggedFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RaggedFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.RaggedFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "RowLengths"
mtype: "<type \'type\'>"
}
member {
name: "RowLimits"
mtype: "<type \'type\'>"
}
member {
name: "RowSplits"
mtype: "<type \'type\'>"
}
member {
name: "RowStarts"
mtype: "<type \'type\'>"
}
member {
name: "UniformRowLength"
mtype: "<type \'type\'>"
}
member {
name: "ValueRowIds"
mtype: "<type \'type\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "partitions"
mtype: "<type \'property\'>"
}
member {
name: "row_splits_dtype"
mtype: "<type \'property\'>"
}
member {
name: "validate"
mtype: "<type \'property\'>"
}
member {
name: "value_key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}
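
Reading the members above together: dtype and value_key select the flat values, each entry in partitions adds one ragged dimension, validate toggles runtime consistency checks on the partition tensors, and row_splits_dtype sets the dtype of the resulting row partitions. A hedged end-to-end sketch (the feature names "values"/"lengths" and the helper below are made up for illustration):

import tensorflow as tf

def make_example(values, lengths):
  # Helper for this sketch only; not part of the change under review.
  int64 = lambda v: tf.train.Feature(int64_list=tf.train.Int64List(value=v))
  return tf.train.Example(features=tf.train.Features(
      feature={"values": int64(values), "lengths": int64(lengths)}
  )).SerializeToString()

serialized = [make_example([1, 2, 3], [2, 1]),
              make_example([4, 5], [2])]
parsed = tf.io.parse_example(serialized, {
    "ragged": tf.io.RaggedFeature(
        dtype=tf.int64,
        value_key="values",
        partitions=[tf.io.RaggedFeature.RowLengths("lengths")],
        validate=True),
})
# parsed["ragged"]  =>  <tf.RaggedTensor [[[1, 2], [3]], [[4, 5]]]>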

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.SparseFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.SparseFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "already_sorted"

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.VarLenFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.VarLenFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "dtype"

View File

@ -20,6 +20,10 @@ tf_module {
name: "QueueBase"
mtype: "<type \'type\'>"
}
member {
name: "RaggedFeature"
mtype: "<type \'type\'>"
}
member {
name: "RandomShuffleQueue"
mtype: "<type \'type\'>"
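
With tf.io.RaggedFeature exported from the module, the minimal form needs only a dtype: value_key defaults to the key the feature is registered under, and with no partitions each Example contributes a single ragged row. Sketch (the "terms" feature name is illustrative):

import tensorflow as tf

def make_example(terms):
  # Sketch-only helper building a serialized Example with one bytes feature.
  return tf.train.Example(features=tf.train.Features(feature={
      "terms": tf.train.Feature(bytes_list=tf.train.BytesList(value=terms)),
  })).SerializeToString()

serialized = [make_example([b"a", b"b"]), make_example([b"c"])]
parsed = tf.io.parse_example(
    serialized, {"terms": tf.io.RaggedFeature(tf.string)})
# parsed["terms"]  =>  <tf.RaggedTensor [[b'a', b'b'], [b'c']]>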

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.FixedLenFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "default_value"

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.FixedLenSequenceFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenSequenceFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.FixedLenSequenceFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "allow_missing"

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowLengths"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowLengths\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowLimits"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowLimits\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowSplits"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowSplits\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.RowStarts"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RowStarts\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.UniformRowLength"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.UniformRowLength\'>"
is_instance: "<type \'tuple\'>"
member {
name: "length"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,18 @@
path: "tensorflow.io.RaggedFeature.ValueRowIds"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.ValueRowIds\'>"
is_instance: "<type \'tuple\'>"
member {
name: "key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,59 @@
path: "tensorflow.io.RaggedFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_config.RaggedFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.RaggedFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "RowLengths"
mtype: "<type \'type\'>"
}
member {
name: "RowLimits"
mtype: "<type \'type\'>"
}
member {
name: "RowSplits"
mtype: "<type \'type\'>"
}
member {
name: "RowStarts"
mtype: "<type \'type\'>"
}
member {
name: "UniformRowLength"
mtype: "<type \'type\'>"
}
member {
name: "ValueRowIds"
mtype: "<type \'type\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "partitions"
mtype: "<type \'property\'>"
}
member {
name: "row_splits_dtype"
mtype: "<type \'property\'>"
}
member {
name: "validate"
mtype: "<type \'property\'>"
}
member {
name: "value_key"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}
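
Same class surface in the v2 goldens. One behavior worth a final sketch: partitions appear to compose from outermost to innermost, so listing two partition keys yields two ragged dimensions (all names and values below are made up):

import tensorflow as tf

def make_example(**int_lists):
  # Sketch-only helper: one int64 feature per keyword argument.
  int64 = lambda v: tf.train.Feature(int64_list=tf.train.Int64List(value=v))
  return tf.train.Example(features=tf.train.Features(
      feature={k: int64(v) for k, v in int_lists.items()}
  )).SerializeToString()

serialized = [make_example(tokens=[1, 2, 3, 4],
                           sent_lens=[2, 1, 1],
                           para_lens=[2, 1])]
parsed = tf.io.parse_example(serialized, {
    "doc": tf.io.RaggedFeature(
        dtype=tf.int64,
        value_key="tokens",
        partitions=[
            tf.io.RaggedFeature.RowLengths("para_lens"),  # outer ragged dim
            tf.io.RaggedFeature.RowLengths("sent_lens"),  # inner ragged dim
        ]),
})
# parsed["doc"]  =>  [[[[1, 2], [3]], [[4]]]]  (batch, para, sentence, token)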

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.SparseFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.SparseFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.SparseFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "already_sorted"

View File

@ -1,7 +1,7 @@
path: "tensorflow.io.VarLenFeature"
tf_class {
is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.VarLenFeature\'>"
is_instance: "<class \'tensorflow.python.ops.parsing_config.VarLenFeature\'>"
is_instance: "<type \'tuple\'>"
member {
name: "dtype"

View File

@ -8,6 +8,10 @@ tf_module {
name: "FixedLenSequenceFeature"
mtype: "<type \'type\'>"
}
member {
name: "RaggedFeature"
mtype: "<type \'type\'>"
}
member {
name: "SparseFeature"
mtype: "<type \'type\'>"