Merge pull request #32669 from tensorflow/ggadde-cp-19
[r2.0-CherryPick]:[tf.data] Avoid double conversion to a tensor during input normalizat…
This commit is contained in:
commit
2646d23074
@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_spec
|
|||||||
from tensorflow.python.framework import type_spec
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -83,24 +84,31 @@ def normalize_element(element):
|
|||||||
components = nest.flatten(element)
|
components = nest.flatten(element)
|
||||||
normalized_components = []
|
normalized_components = []
|
||||||
with ops.name_scope("normalize_element"):
|
with ops.name_scope("normalize_element"):
|
||||||
# Imported here to avoid circular dependency
|
# Imported here to avoid circular dependency.
|
||||||
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
|
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
|
||||||
for i, t in enumerate(components):
|
for i, t in enumerate(components):
|
||||||
spec = type_spec_from_value(t)
|
try:
|
||||||
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
spec = type_spec_from_value(t, use_fallback=False)
|
||||||
normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
|
except TypeError:
|
||||||
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
|
# TypeError indicates it was not possible to compute a `TypeSpec` for
|
||||||
normalized_components.append(
|
# the value. As a fallback try converting the value to a tensor.
|
||||||
ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
|
||||||
t, name="component_%d" % i))
|
|
||||||
elif isinstance(
|
|
||||||
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
|
|
||||||
normalized_components.append(t)
|
|
||||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
|
||||||
normalized_components.append(t)
|
|
||||||
else:
|
|
||||||
normalized_components.append(
|
normalized_components.append(
|
||||||
ops.convert_to_tensor(t, name="component_%d" % i))
|
ops.convert_to_tensor(t, name="component_%d" % i))
|
||||||
|
else:
|
||||||
|
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||||
|
normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
|
||||||
|
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
|
||||||
|
normalized_components.append(
|
||||||
|
ragged_tensor.convert_to_tensor_or_ragged_tensor(
|
||||||
|
t, name="component_%d" % i))
|
||||||
|
elif isinstance(
|
||||||
|
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
|
||||||
|
normalized_components.append(t)
|
||||||
|
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||||
|
normalized_components.append(t)
|
||||||
|
else:
|
||||||
|
normalized_components.append(
|
||||||
|
ops.convert_to_tensor(t, name="component_%d" % i))
|
||||||
return nest.pack_sequence_as(element, normalized_components)
|
return nest.pack_sequence_as(element, normalized_components)
|
||||||
|
|
||||||
|
|
||||||
@ -392,11 +400,13 @@ def are_compatible(spec1, spec2):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def type_spec_from_value(element):
|
def type_spec_from_value(element, use_fallback=True):
|
||||||
"""Creates a type specification for the given value.
|
"""Creates a type specification for the given value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
element: The element to create the type specification for.
|
element: The element to create the type specification for.
|
||||||
|
use_fallback: Whether to fall back to converting the element to a tensor
|
||||||
|
in order to compute its `TypeSpec`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A nested structure of `TypeSpec`s that represents the type specification
|
A nested structure of `TypeSpec`s that represents the type specification
|
||||||
@ -432,14 +442,16 @@ def type_spec_from_value(element):
|
|||||||
# `element` is not a namedtuple
|
# `element` is not a namedtuple
|
||||||
return tuple([type_spec_from_value(v) for v in element])
|
return tuple([type_spec_from_value(v) for v in element])
|
||||||
|
|
||||||
# Fallback: try converting value to a tensor.
|
if use_fallback:
|
||||||
try:
|
# As a fallback try converting the element to a tensor.
|
||||||
tensor = ops.convert_to_tensor(element)
|
try:
|
||||||
spec = type_spec_from_value(tensor)
|
tensor = ops.convert_to_tensor(element)
|
||||||
if spec is not None:
|
spec = type_spec_from_value(tensor)
|
||||||
return spec
|
if spec is not None:
|
||||||
except (ValueError, TypeError):
|
return spec
|
||||||
pass
|
except (ValueError, TypeError) as e:
|
||||||
|
logging.vlog(
|
||||||
|
3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e))
|
||||||
|
|
||||||
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
||||||
(element, type(element).__name__))
|
(element, type(element).__name__))
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.python import pywrap_tensorflow
|
|||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
@ -483,8 +484,9 @@ def type_spec_from_value(value):
|
|||||||
spec = _type_spec_from_value(tensor)
|
spec = _type_spec_from_value(tensor)
|
||||||
if spec is not None:
|
if spec is not None:
|
||||||
return spec
|
return spec
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError) as e:
|
||||||
pass
|
logging.vlog(
|
||||||
|
3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))
|
||||||
|
|
||||||
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
raise TypeError("Could not build a TypeSpec for %r with type %s" %
|
||||||
(value, type(value).__name__))
|
(value, type(value).__name__))
|
||||||
|
Loading…
Reference in New Issue
Block a user