diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 2e0e7786d91..bd1e94fd79e 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 import abc
 import enum
 import functools
+import sys
 import threading
 import warnings
 import weakref
@@ -599,20 +600,20 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
       try:
         flattened_values = nest.flatten_up_to(output_types, values)
       except (TypeError, ValueError):
-        raise TypeError(
+        six.reraise(TypeError, TypeError(
             "`generator` yielded an element that did not match the expected "
             "structure. The expected structure was %s, but the yielded "
-            "element was %s." % (output_types, values))
+            "element was %s." % (output_types, values)), sys.exc_info()[2])
       ret_arrays = []
       for ret, dtype in zip(flattened_values, flattened_types):
         try:
           ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
               ret, dtype=dtype.as_numpy_dtype))
         except (TypeError, ValueError):
-          raise TypeError(
+          six.reraise(TypeError, TypeError(
               "`generator` yielded an element that could not be converted to "
               "the expected type. The expected type was %s, but the yielded "
-              "element was %s." % (dtype.name, ret))
+              "element was %s." % (dtype.name, ret)), sys.exc_info()[2])
 
       # Additional type and shape checking to ensure that the components
       # of the generated element match the `output_types` and `output_shapes`
@@ -2682,8 +2683,11 @@ class StructuredFunctionWrapper(object):
       try:
         self._output_structure = structure.type_spec_from_value(ret)
       except (ValueError, TypeError):
-        raise TypeError("Unsupported return value from function passed to "
-                        "%s: %s." % (transformation_name, ret))
+        six.reraise(
+            TypeError,
+            TypeError("Unsupported return value from function passed to "
+                      "%s: %s." % (transformation_name, ret)),
+            sys.exc_info()[2])
       return ret
 
     if use_legacy_function:
@@ -3287,13 +3291,16 @@ def _padded_shape_to_tensor(padded_shape, input_component_shape):
     # machinery.
     ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
     if ret.shape.dims is not None and len(ret.shape.dims) != 1:
-      raise ValueError(
+      six.reraise(ValueError, ValueError(
           "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
-          "shape was %s." % (padded_shape, ret.shape))
+          "shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2])
     if ret.dtype != dtypes.int64:
-      raise TypeError(
-          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
-          "element type was %s." % (padded_shape, ret.dtype.name))
+      six.reraise(
+          TypeError,
+          TypeError(
+              "Padded shape %s must be a 1-D tensor of tf.int64 values, but "
+              "its element type was %s." % (padded_shape, ret.dtype.name)),
+          sys.exc_info()[2])
 
   padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
   if not _is_padded_shape_compatible_with(padded_shape_as_shape,
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
index 645d95bdd68..f9ad93b86a1 100644
--- a/tensorflow/python/data/util/structure.py
+++ b/tensorflow/python/data/util/structure.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import deprecation
+from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -410,7 +411,7 @@ def type_spec_from_value(element):
   if spec is not None:
     return spec
 
-  if isinstance(element, dict):
+  if isinstance(element, collections_abc.Mapping):
     # We create a shallow copy in an attempt to preserve the key order.
     #
     # Note that we do not guarantee that the key order is preserved, which is
@@ -418,10 +419,11 @@ def type_spec_from_value(element):
     # `type_spec_from_value` should not assume that the key order of a `dict`
     # in the returned nested structure matches the key order of the
     # corresponding `dict` in the input value.
-    result = element.copy()
-    for k in element:
-      result[k] = type_spec_from_value(element[k])
-    return result
+    if isinstance(element, collections.defaultdict):
+      ctor = lambda items: type(element)(element.default_factory, items)
+    else:
+      ctor = type(element)
+    return ctor([(k, type_spec_from_value(v)) for k, v in element.items()])
 
   if isinstance(element, tuple):
     if hasattr(element, "_fields") and isinstance(
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index 290dc99df27..4127d694ab8 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import ragged_tensor_value
 from tensorflow.python.platform import test
+from tensorflow.python.util.compat import collections_abc
 
 
 # NOTE(mrry): Arguments of parameterized tests are lifted into lambdas to make
@@ -738,6 +739,28 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
     # Note: shape was automatically converted from a list to a TensorShape.
     self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
 
+  def testCustomMapping(self):
+    elem = CustomMap(foo=constant_op.constant(37.))
+    spec = structure.type_spec_from_value(elem)
+    self.assertIsInstance(spec, CustomMap)
+    self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32))
+
+
+class CustomMap(collections_abc.Mapping):
+  """Custom, immutable map."""
+
+  def __init__(self, *args, **kwargs):
+    self.__dict__.update(dict(*args, **kwargs))
+
+  def __getitem__(self, x):
+    return self.__dict__[x]
+
+  def __iter__(self):
+    return iter(self.__dict__)
+
+  def __len__(self):
+    return len(self.__dict__)
+
 
 if __name__ == "__main__":
   test.main()
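
Note on the dataset_ops.py hunks: they replace plain `raise TypeError(...)` / `raise ValueError(...)` inside `except` blocks with `six.reraise(...)`, which attaches the original traceback to the replacement exception. Below is a minimal, self-contained sketch of that idiom; it is not part of the patch, and `parse_shape` is a hypothetical helper used only to illustrate why the traceback matters.

import sys

import six


def parse_shape(values):
  """Hypothetical helper: wraps a low-level failure in a clearer TypeError."""
  try:
    return [int(v) for v in values]
  except (TypeError, ValueError):
    # six.reraise(tp, value, tb) re-raises `value` with the traceback of the
    # exception currently being handled, so the frame that actually failed
    # (the int() call above) is preserved. On Python 3 this is roughly
    # `raise value.with_traceback(tb)`; on Python 2 a plain `raise TypeError`
    # would discard the original traceback, which is what the patch avoids.
    six.reraise(
        TypeError,
        TypeError("Shape %r could not be converted to a list of ints." %
                  (values,)),
        sys.exc_info()[2])


print(parse_shape(["1", "2"]))  # [1, 2]; parse_shape(["a"]) re-raises.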
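Note on the structure.py hunk: instead of mutating a shallow copy, the new code rebuilds the mapping through its own constructor so the result keeps the input's concrete type, with `defaultdict` special-cased because its constructor takes the `default_factory` as a leading argument. The sketch below shows the same idea outside TensorFlow; `rebuild_mapping` is an illustrative name, not an API from the patch.

import collections


def rebuild_mapping(element, fn):
  """Applies `fn` to each value of `element`, keeping its mapping type."""
  if isinstance(element, collections.defaultdict):
    # defaultdict(default_factory, items): the factory must be passed back
    # explicitly, or the rebuilt mapping would lose it.
    ctor = lambda items: type(element)(element.default_factory, items)
  else:
    # dict, OrderedDict and most custom Mapping types accept an iterable of
    # (key, value) pairs.
    ctor = type(element)
  return ctor([(k, fn(v)) for k, v in element.items()])


d = collections.defaultdict(list, {"a": 1, "b": 2})
out = rebuild_mapping(d, str)
print(type(out).__name__, out.default_factory is list, dict(out))
# defaultdict True {'a': '1', 'b': '2'}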