Remove deprecated APIs from CompositeTensor
PiperOrigin-RevId: 272745538
This commit is contained in:
parent
c73f9f85fc
commit
84f9d53683
@ -1185,8 +1185,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
|
|||||||
def testIsGraphTensor(self):
|
def testIsGraphTensor(self):
|
||||||
per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
|
per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
|
||||||
(constant_op.constant(1.),))
|
(constant_op.constant(1.),))
|
||||||
self.assertEqual(per_replica._is_graph_tensor,
|
for t in nest.flatten(per_replica, expand_composites=True):
|
||||||
not context.executing_eagerly())
|
self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())
|
||||||
|
|
||||||
def testDoesNotTriggerFunctionTracing(self):
|
def testDoesNotTriggerFunctionTracing(self):
|
||||||
traces = []
|
traces = []
|
||||||
@ -1223,9 +1223,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
|
|||||||
values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
|
values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
|
||||||
y = f(x)
|
y = f(x)
|
||||||
self.assertIsNot(x, y)
|
self.assertIsNot(x, y)
|
||||||
for a, b in zip(x._to_components(), y._to_components()):
|
nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
|
||||||
self.assertAllEqual(a, b)
|
self.assertEqual(x._type_spec, y._type_spec)
|
||||||
self.assertEqual(x._component_metadata(), y._component_metadata())
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testCondWithTensorValues(self):
|
def testCondWithTensorValues(self):
|
||||||
|
@ -22,8 +22,8 @@ import abc
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
|
||||||
from tensorflow.python import _pywrap_utils
|
from tensorflow.python import _pywrap_utils
|
||||||
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
@ -53,46 +53,6 @@ class CompositeTensor(object):
|
|||||||
"""A `TypeSpec` describing the type of this value."""
|
"""A `TypeSpec` describing the type of this value."""
|
||||||
raise NotImplementedError("%s._type_spec()" % type(self).__name__)
|
raise NotImplementedError("%s._type_spec()" % type(self).__name__)
|
||||||
|
|
||||||
# Deprecated -- use self._type_spec._to_components(self) instead.
|
|
||||||
# TODO(b/133606651) Remove all callers and then delete this method.
|
|
||||||
def _to_components(self):
|
|
||||||
"""Decomposes this composite tensor into its component tensors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A nested structure of `tf.Tensor`s and `CompositeTensor`s that can be
|
|
||||||
used to reconstruct this composite tensor (along with metadata returned
|
|
||||||
by `_component_metadata`).
|
|
||||||
"""
|
|
||||||
return self._type_spec._to_components(self) # pylint: disable=protected-access
|
|
||||||
|
|
||||||
# Deprecated -- use self._type_spec instead.
|
|
||||||
# TODO(b/133606651) Remove all callers and then delete this method.
|
|
||||||
def _component_metadata(self):
|
|
||||||
"""Returns any non-tensor metadata needed to reconstruct a composite tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A nested structure of metadata that can be used to reconstruct this
|
|
||||||
composite tensor (along with the tensors returned by `_to_components`).
|
|
||||||
"""
|
|
||||||
return self._type_spec
|
|
||||||
|
|
||||||
# Deprecated -- use metadata._from_components(components) instead.
|
|
||||||
# TODO(b/133606651) Remove all callers and then delete this method.
|
|
||||||
@staticmethod
|
|
||||||
def _from_components(components, metadata):
|
|
||||||
"""Creates a composite tensor of type `cls` from components.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
components: A nested structure whose values are `tf.Tensor`s or
|
|
||||||
`tf.CompositeTensor`s (as returned by `_to_components`).
|
|
||||||
metadata: A nested structure containing any additional metadata needed to
|
|
||||||
reconstruct the composite tensor (as returned by `_composite_metadata`).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A `CompositeTensor` of type `cls`.
|
|
||||||
"""
|
|
||||||
return metadata._from_components(components) # pylint: disable=protected-access
|
|
||||||
|
|
||||||
def _shape_invariant_to_type_spec(self, shape):
|
def _shape_invariant_to_type_spec(self, shape):
|
||||||
"""Returns a TypeSpec given a shape invariant (used by `tf.while_loop`).
|
"""Returns a TypeSpec given a shape invariant (used by `tf.while_loop`).
|
||||||
|
|
||||||
@ -111,16 +71,6 @@ class CompositeTensor(object):
|
|||||||
raise NotImplementedError("%s._shape_invariant_to_type_spec"
|
raise NotImplementedError("%s._shape_invariant_to_type_spec"
|
||||||
% type(self).__name__)
|
% type(self).__name__)
|
||||||
|
|
||||||
# TODO(b/133606651) Remove this property, since it's not clear what it should
|
|
||||||
# return if a CompositeTensor has a mix of graph and non-graph components.
|
|
||||||
# Update users to perform an appropraite check themselves.
|
|
||||||
@property
|
|
||||||
def _is_graph_tensor(self):
|
|
||||||
"""Returns True if this tensor's components belong to a TF graph."""
|
|
||||||
components = self._type_spec._to_components(self) # pylint: disable=protected-access
|
|
||||||
tensors = nest.flatten(components, expand_composites=True)
|
|
||||||
return any(hasattr(t, "graph") for t in tensors)
|
|
||||||
|
|
||||||
def _consumers(self):
|
def _consumers(self):
|
||||||
"""Returns a list of `Operation`s that consume this `CompositeTensor`.
|
"""Returns a list of `Operation`s that consume this `CompositeTensor`.
|
||||||
|
|
||||||
@ -132,7 +82,7 @@ class CompositeTensor(object):
|
|||||||
"""
|
"""
|
||||||
consumers = nest.flatten([
|
consumers = nest.flatten([
|
||||||
component.consumers()
|
component.consumers()
|
||||||
for component in self._to_components()
|
for component in nest.flatten(self, expand_composites=True)
|
||||||
if getattr(component, "graph", None) is not None
|
if getattr(component, "graph", None) is not None
|
||||||
])
|
])
|
||||||
return list(set(consumers))
|
return list(set(consumers))
|
||||||
|
@ -342,7 +342,8 @@ def is_symbolic_tensor(tensor):
|
|||||||
return (getattr(tensor, '_keras_history', False) or
|
return (getattr(tensor, '_keras_history', False) or
|
||||||
not context.executing_eagerly())
|
not context.executing_eagerly())
|
||||||
if isinstance(tensor, composite_tensor.CompositeTensor):
|
if isinstance(tensor, composite_tensor.CompositeTensor):
|
||||||
return tensor._is_graph_tensor # pylint: disable=protected-access
|
component_tensors = nest.flatten(tensor, expand_composites=True)
|
||||||
|
return any(hasattr(t, 'graph') for t in component_tensors)
|
||||||
if isinstance(tensor, ops.Tensor):
|
if isinstance(tensor, ops.Tensor):
|
||||||
return hasattr(tensor, 'graph')
|
return hasattr(tensor, 'graph')
|
||||||
return False
|
return False
|
||||||
|
@ -39,8 +39,8 @@ import collections as _collections
|
|||||||
import six as _six
|
import six as _six
|
||||||
|
|
||||||
from tensorflow.python import _pywrap_utils
|
from tensorflow.python import _pywrap_utils
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
|
||||||
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
_SHALLOW_TREE_HAS_INVALID_KEYS = (
|
_SHALLOW_TREE_HAS_INVALID_KEYS = (
|
||||||
@ -197,7 +197,8 @@ def _yield_sorted_items(iterable):
|
|||||||
for field in iterable._fields:
|
for field in iterable._fields:
|
||||||
yield field, getattr(iterable, field)
|
yield field, getattr(iterable, field)
|
||||||
elif _is_composite_tensor(iterable):
|
elif _is_composite_tensor(iterable):
|
||||||
yield type(iterable).__name__, iterable._to_components() # pylint: disable=protected-access
|
type_spec = iterable._type_spec # pylint: disable=protected-access
|
||||||
|
yield type(iterable).__name__, type_spec._to_components(iterable) # pylint: disable=protected-access
|
||||||
elif _is_type_spec(iterable):
|
elif _is_type_spec(iterable):
|
||||||
# Note: to allow CompositeTensors and their TypeSpecs to have matching
|
# Note: to allow CompositeTensors and their TypeSpecs to have matching
|
||||||
# structures, we need to use the same key string here.
|
# structures, we need to use the same key string here.
|
||||||
|
Loading…
Reference in New Issue
Block a user