Merge pull request #32669 from tensorflow/ggadde-cp-19
[r2.0-CherryPick]:[tf.data] Avoid double conversion to a tensor during input normalizat…
This commit is contained in:
commit
2646d23074
@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -83,24 +84,31 @@ def normalize_element(element):
|
||||
components = nest.flatten(element)
|
||||
normalized_components = []
|
||||
with ops.name_scope("normalize_element"):
|
||||
# Imported here to avoid circular dependency
|
||||
# Imported here to avoid circular dependency.
|
||||
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
|
||||
for i, t in enumerate(components):
|
||||
spec = type_spec_from_value(t)
|
||||
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||
normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
|
||||
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
|
||||
normalized_components.append(
|
||||
ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
t, name="component_%d" % i))
|
||||
elif isinstance(
|
||||
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
|
||||
normalized_components.append(t)
|
||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||
normalized_components.append(t)
|
||||
else:
|
||||
try:
|
||||
spec = type_spec_from_value(t, use_fallback=False)
|
||||
except TypeError:
|
||||
# TypeError indicates it was not possible to compute a `TypeSpec` for
|
||||
# the value. As a fallback try converting the value to a tensor.
|
||||
normalized_components.append(
|
||||
ops.convert_to_tensor(t, name="component_%d" % i))
|
||||
else:
|
||||
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||
normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
|
||||
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
|
||||
normalized_components.append(
|
||||
ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||
t, name="component_%d" % i))
|
||||
elif isinstance(
|
||||
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
|
||||
normalized_components.append(t)
|
||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||
normalized_components.append(t)
|
||||
else:
|
||||
normalized_components.append(
|
||||
ops.convert_to_tensor(t, name="component_%d" % i))
|
||||
return nest.pack_sequence_as(element, normalized_components)
|
||||
|
||||
|
||||
@ -392,11 +400,13 @@ def are_compatible(spec1, spec2):
|
||||
return True
|
||||
|
||||
|
||||
def type_spec_from_value(element):
|
||||
def type_spec_from_value(element, use_fallback=True):
|
||||
"""Creates a type specification for the given value.
|
||||
|
||||
Args:
|
||||
element: The element to create the type specification for.
|
||||
use_fallback: Whether to fall back to converting the element to a tensor
|
||||
in order to compute its `TypeSpec`.
|
||||
|
||||
Returns:
|
||||
A nested structure of `TypeSpec`s that represents the type specification
|
||||
@ -432,14 +442,16 @@ def type_spec_from_value(element):
|
||||
# `element` is not a namedtuple
|
||||
return tuple([type_spec_from_value(v) for v in element])
|
||||
|
||||
# Fallback: try converting value to a tensor.
|
||||
try:
|
||||
tensor = ops.convert_to_tensor(element)
|
||||
spec = type_spec_from_value(tensor)
|
||||
if spec is not None:
|
||||
return spec
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if use_fallback:
|
||||
# As a fallback try converting the element to a tensor.
|
||||
try:
|
||||
tensor = ops.convert_to_tensor(element)
|
||||
spec = type_spec_from_value(tensor)
|
||||
if spec is not None:
|
||||
return spec
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.vlog(
|
||||
3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e))
|
||||
|
||||
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
||||
(element, type(element).__name__))
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
@ -483,8 +484,9 @@ def type_spec_from_value(value):
|
||||
spec = _type_spec_from_value(tensor)
|
||||
if spec is not None:
|
||||
return spec
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.vlog(
|
||||
3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))
|
||||
|
||||
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
||||
(value, type(value).__name__))
|
||||
|
Loading…
Reference in New Issue
Block a user