Support non-dict maps in tf.data.util.structure.type_spec_from_value.
PiperOrigin-RevId: 266934779
This commit is contained in:
parent
d70a2cf2ab
commit
e21b2570a3
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import abc
|
||||
import enum
|
||||
import functools
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
import weakref
|
||||
@ -599,20 +600,20 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
try:
|
||||
flattened_values = nest.flatten_up_to(output_types, values)
|
||||
except (TypeError, ValueError):
|
||||
raise TypeError(
|
||||
six.reraise(TypeError, TypeError(
|
||||
"`generator` yielded an element that did not match the expected "
|
||||
"structure. The expected structure was %s, but the yielded "
|
||||
"element was %s." % (output_types, values))
|
||||
"element was %s." % (output_types, values)), sys.exc_info()[2])
|
||||
ret_arrays = []
|
||||
for ret, dtype in zip(flattened_values, flattened_types):
|
||||
try:
|
||||
ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access
|
||||
ret, dtype=dtype.as_numpy_dtype))
|
||||
except (TypeError, ValueError):
|
||||
raise TypeError(
|
||||
six.reraise(TypeError, TypeError(
|
||||
"`generator` yielded an element that could not be converted to "
|
||||
"the expected type. The expected type was %s, but the yielded "
|
||||
"element was %s." % (dtype.name, ret))
|
||||
"element was %s." % (dtype.name, ret)), sys.exc_info()[2])
|
||||
|
||||
# Additional type and shape checking to ensure that the components
|
||||
# of the generated element match the `output_types` and `output_shapes`
|
||||
@ -2682,8 +2683,11 @@ class StructuredFunctionWrapper(object):
|
||||
try:
|
||||
self._output_structure = structure.type_spec_from_value(ret)
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError("Unsupported return value from function passed to "
|
||||
"%s: %s." % (transformation_name, ret))
|
||||
six.reraise(
|
||||
TypeError,
|
||||
TypeError("Unsupported return value from function passed to "
|
||||
"%s: %s." % (transformation_name, ret)),
|
||||
sys.exc_info()[2])
|
||||
return ret
|
||||
|
||||
if use_legacy_function:
|
||||
@ -3287,13 +3291,16 @@ def _padded_shape_to_tensor(padded_shape, input_component_shape):
|
||||
# machinery.
|
||||
ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
|
||||
if ret.shape.dims is not None and len(ret.shape.dims) != 1:
|
||||
raise ValueError(
|
||||
six.reraise(ValueError, ValueError(
|
||||
"Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
|
||||
"shape was %s." % (padded_shape, ret.shape))
|
||||
"shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2])
|
||||
if ret.dtype != dtypes.int64:
|
||||
raise TypeError(
|
||||
"Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
|
||||
"element type was %s." % (padded_shape, ret.dtype.name))
|
||||
six.reraise(
|
||||
TypeError,
|
||||
TypeError(
|
||||
"Padded shape %s must be a 1-D tensor of tf.int64 values, but "
|
||||
"its element type was %s." % (padded_shape, ret.dtype.name)),
|
||||
sys.exc_info()[2])
|
||||
padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
|
||||
|
||||
if not _is_padded_shape_compatible_with(padded_shape_as_shape,
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -410,7 +411,7 @@ def type_spec_from_value(element):
|
||||
if spec is not None:
|
||||
return spec
|
||||
|
||||
if isinstance(element, dict):
|
||||
if isinstance(element, collections_abc.Mapping):
|
||||
# We create a shallow copy in an attempt to preserve the key order.
|
||||
#
|
||||
# Note that we do not guarantee that the key order is preserved, which is
|
||||
@ -418,10 +419,11 @@ def type_spec_from_value(element):
|
||||
# `type_spec_from_value` should not assume that the key order of a `dict`
|
||||
# in the returned nested structure matches the key order of the
|
||||
# corresponding `dict` in the input value.
|
||||
result = element.copy()
|
||||
for k in element:
|
||||
result[k] = type_spec_from_value(element[k])
|
||||
return result
|
||||
if isinstance(element, collections.defaultdict):
|
||||
ctor = lambda items: type(element)(element.default_factory, items)
|
||||
else:
|
||||
ctor = type(element)
|
||||
return ctor([(k, type_spec_from_value(v)) for k, v in element.items()])
|
||||
|
||||
if isinstance(element, tuple):
|
||||
if hasattr(element, "_fields") and isinstance(
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
|
||||
|
||||
# NOTE(mrry): Arguments of parameterized tests are lifted into lambdas to make
|
||||
@ -738,6 +739,28 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
# Note: shape was automatically converted from a list to a TensorShape.
|
||||
self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
|
||||
|
||||
def testCustomMapping(self):
|
||||
elem = CustomMap(foo=constant_op.constant(37.))
|
||||
spec = structure.type_spec_from_value(elem)
|
||||
self.assertIsInstance(spec, CustomMap)
|
||||
self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32))
|
||||
|
||||
|
||||
class CustomMap(collections_abc.Mapping):
|
||||
"""Custom, immutable map."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.__dict__.update(dict(*args, **kwargs))
|
||||
|
||||
def __getitem__(self, x):
|
||||
return self.__dict__[x]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.__dict__)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__dict__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user