Support non-dict maps in tf.data.util.structure.type_spec_from_value.

PiperOrigin-RevId: 266934779
Author: A. Unique TensorFlower
Date: 2019-09-03 08:32:24 -07:00
Committed by: TensorFlower Gardener
parent d70a2cf2ab
commit e21b2570a3
3 changed files with 48 additions and 16 deletions
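
Concretely, `type_spec_from_value` now accepts any `collections.abc.Mapping` subclass, not only a plain `dict`, and returns a mapping of the same concrete type whose values are the inferred `TypeSpec`s. Below is a minimal usage sketch, under two assumptions: the internal module path `tensorflow.python.data.util.structure` is inferred from the commit title, and `FrozenMap` is a hypothetical read-only mapping that mirrors the `CustomMap` fixture added in the test file of this commit.

import collections.abc

from tensorflow.python.data.util import structure  # path inferred from the commit title
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec


class FrozenMap(collections.abc.Mapping):
  """Hypothetical read-only mapping; not part of the commit itself."""

  def __init__(self, *args, **kwargs):
    self._data = dict(*args, **kwargs)

  def __getitem__(self, key):
    return self._data[key]

  def __iter__(self):
    return iter(self._data)

  def __len__(self):
    return len(self._data)


elem = FrozenMap(foo=constant_op.constant(37.0))
spec = structure.type_spec_from_value(elem)
# The result preserves the concrete mapping type; values become TensorSpecs.
assert isinstance(spec, FrozenMap)
assert spec["foo"] == tensor_spec.TensorSpec([], dtypes.float32)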

View File

@@ -20,6 +20,7 @@ from __future__ import print_function
 import abc
 import enum
 import functools
+import sys
 import threading
 import warnings
 import weakref
@@ -599,20 +600,20 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
       try:
         flattened_values = nest.flatten_up_to(output_types, values)
       except (TypeError, ValueError):
-        raise TypeError(
-            "`generator` yielded an element that did not match the expected "
-            "structure. The expected structure was %s, but the yielded "
-            "element was %s." % (output_types, values))
+        six.reraise(TypeError, TypeError(
+            "`generator` yielded an element that did not match the expected "
+            "structure. The expected structure was %s, but the yielded "
+            "element was %s." % (output_types, values)), sys.exc_info()[2])
       ret_arrays = []
       for ret, dtype in zip(flattened_values, flattened_types):
         try:
           ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
               ret, dtype=dtype.as_numpy_dtype))
         except (TypeError, ValueError):
-          raise TypeError(
-              "`generator` yielded an element that could not be converted to "
-              "the expected type. The expected type was %s, but the yielded "
-              "element was %s." % (dtype.name, ret))
+          six.reraise(TypeError, TypeError(
+              "`generator` yielded an element that could not be converted to "
+              "the expected type. The expected type was %s, but the yielded "
+              "element was %s." % (dtype.name, ret)), sys.exc_info()[2])
 
       # Additional type and shape checking to ensure that the components
       # of the generated element match the `output_types` and `output_shapes`
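
Note on the pattern: the change from a bare `raise` to `six.reraise(..., sys.exc_info()[2])` in this hunk (and the ones below) attaches the original traceback to the re-typed exception in a way that works on both Python 2 and Python 3, so the error still points at the call that actually failed. A standalone sketch of the pattern, independent of TensorFlow (`parse_count` is a hypothetical helper used only for illustration):

import sys

import six


def parse_count(value):
  try:
    return int(value)
  except ValueError:
    # Re-raise as a TypeError while keeping the traceback of the failing
    # int() call, which a plain `raise TypeError(...)` would lose on Python 2.
    six.reraise(
        TypeError,
        TypeError("expected an integer-like value, got %r" % (value,)),
        sys.exc_info()[2])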
@@ -2682,8 +2683,11 @@ class StructuredFunctionWrapper(object):
       try:
         self._output_structure = structure.type_spec_from_value(ret)
       except (ValueError, TypeError):
-        raise TypeError("Unsupported return value from function passed to "
-                        "%s: %s." % (transformation_name, ret))
+        six.reraise(
+            TypeError,
+            TypeError("Unsupported return value from function passed to "
+                      "%s: %s." % (transformation_name, ret)),
+            sys.exc_info()[2])
       return ret
 
     if use_legacy_function:
@@ -3287,13 +3291,16 @@ def _padded_shape_to_tensor(padded_shape, input_component_shape):
     # machinery.
     ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
     if ret.shape.dims is not None and len(ret.shape.dims) != 1:
-      raise ValueError(
-          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
-          "shape was %s." % (padded_shape, ret.shape))
+      six.reraise(ValueError, ValueError(
+          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
+          "shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2])
     if ret.dtype != dtypes.int64:
-      raise TypeError(
-          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
-          "element type was %s." % (padded_shape, ret.dtype.name))
+      six.reraise(
+          TypeError,
+          TypeError(
+              "Padded shape %s must be a 1-D tensor of tf.int64 values, but "
+              "its element type was %s." % (padded_shape, ret.dtype.name)),
+          sys.exc_info()[2])
 
     padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
     if not _is_padded_shape_compatible_with(padded_shape_as_shape,

View File

@@ -32,6 +32,7 @@ from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import deprecation
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -410,7 +411,7 @@ def type_spec_from_value(element):
   if spec is not None:
     return spec
 
-  if isinstance(element, dict):
+  if isinstance(element, collections_abc.Mapping):
     # We create a shallow copy in an attempt to preserve the key order.
     #
     # Note that we do not guarantee that the key order is preserved, which is
@@ -418,10 +419,11 @@ def type_spec_from_value(element):
     # `type_spec_from_value` should not assume that the key order of a `dict`
     # in the returned nested structure matches the key order of the
     # corresponding `dict` in the input value.
-    result = element.copy()
-    for k in element:
-      result[k] = type_spec_from_value(element[k])
-    return result
+    if isinstance(element, collections.defaultdict):
+      ctor = lambda items: type(element)(element.default_factory, items)
+    else:
+      ctor = type(element)
+    return ctor([(k, type_spec_from_value(v)) for k, v in element.items()])
 
   if isinstance(element, tuple):
     if hasattr(element, "_fields") and isinstance(
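
Note on the `ctor` indirection above: the rebuilt mapping is constructed from an iterable of `(key, TypeSpec)` pairs, and `collections.defaultdict` cannot be built that way directly because its first positional argument must be the `default_factory`. A small standalone sketch of that detail (plain Python, unrelated to TensorFlow):

import collections

d = collections.defaultdict(list, {"a": 1})

# type(d)(d.items()) would raise TypeError: defaultdict's first positional
# argument must be callable or None, so the factory is passed explicitly.
rebuilt = type(d)(d.default_factory, [(k, v * 2) for k, v in d.items()])
assert rebuilt["a"] == 2
assert rebuilt.default_factory is list

Every other Mapping subclass is assumed to accept an iterable of key/value pairs, which the `CustomMap` test fixture below satisfies by forwarding `*args` to `dict`.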

View File

@@ -39,6 +39,7 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_tensor_value
 from tensorflow.python.platform import test
+from tensorflow.python.util.compat import collections_abc
 
 
 # NOTE(mrry): Arguments of parameterized tests are lifted into lambdas to make
@@ -738,6 +739,28 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
     # Note: shape was automatically converted from a list to a TensorShape.
     self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
 
+  def testCustomMapping(self):
+    elem = CustomMap(foo=constant_op.constant(37.))
+    spec = structure.type_spec_from_value(elem)
+    self.assertIsInstance(spec, CustomMap)
+    self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32))
+
+
+class CustomMap(collections_abc.Mapping):
+  """Custom, immutable map."""
+
+  def __init__(self, *args, **kwargs):
+    self.__dict__.update(dict(*args, **kwargs))
+
+  def __getitem__(self, x):
+    return self.__dict__[x]
+
+  def __iter__(self):
+    return iter(self.__dict__)
+
+  def __len__(self):
+    return len(self.__dict__)
+
 
 if __name__ == "__main__":
   test.main()