Remove deprecated APIs from CompositeTensor
PiperOrigin-RevId: 272745538
This commit is contained in:
parent
c73f9f85fc
commit
84f9d53683
@ -1185,8 +1185,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
|
||||
def testIsGraphTensor(self):
|
||||
per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
|
||||
(constant_op.constant(1.),))
|
||||
self.assertEqual(per_replica._is_graph_tensor,
|
||||
not context.executing_eagerly())
|
||||
for t in nest.flatten(per_replica, expand_composites=True):
|
||||
self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())
|
||||
|
||||
def testDoesNotTriggerFunctionTracing(self):
|
||||
traces = []
|
||||
@ -1223,9 +1223,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
|
||||
values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
|
||||
y = f(x)
|
||||
self.assertIsNot(x, y)
|
||||
for a, b in zip(x._to_components(), y._to_components()):
|
||||
self.assertAllEqual(a, b)
|
||||
self.assertEqual(x._component_metadata(), y._component_metadata())
|
||||
nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
|
||||
self.assertEqual(x._type_spec, y._type_spec)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCondWithTensorValues(self):
|
||||
|
@ -22,8 +22,8 @@ import abc
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@ -53,46 +53,6 @@ class CompositeTensor(object):
|
||||
"""A `TypeSpec` describing the type of this value."""
|
||||
raise NotImplementedError("%s._type_spec()" % type(self).__name__)
|
||||
|
||||
# Deprecated -- use self._type_spec._to_components(self) instead.
|
||||
# TODO(b/133606651) Remove all callers and then delete this method.
|
||||
def _to_components(self):
|
||||
"""Decomposes this composite tensor into its component tensors.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.Tensor`s and `CompositeTensor`s that can be
|
||||
used to reconstruct this composite tensor (along with metadata returned
|
||||
by `_component_metadata`).
|
||||
"""
|
||||
return self._type_spec._to_components(self) # pylint: disable=protected-access
|
||||
|
||||
# Deprecated -- use self._type_spec instead.
|
||||
# TODO(b/133606651) Remove all callers and then delete this method.
|
||||
def _component_metadata(self):
|
||||
"""Returns any non-tensor metadata needed to reconstruct a composite tensor.
|
||||
|
||||
Returns:
|
||||
A nested structure of metadata that can be used to reconstruct this
|
||||
composite tensor (along with the tensors returned by `_to_components`).
|
||||
"""
|
||||
return self._type_spec
|
||||
|
||||
# Deprecated -- use metadata._from_components(components) instead.
|
||||
# TODO(b/133606651) Remove all callers and then delete this method.
|
||||
@staticmethod
|
||||
def _from_components(components, metadata):
|
||||
"""Creates a composite tensor of type `cls` from components.
|
||||
|
||||
Args:
|
||||
components: A nested structure whose values are `tf.Tensor`s or
|
||||
`tf.CompositeTensor`s (as returned by `_to_components`).
|
||||
metadata: A nested structure containing any additional metadata needed to
|
||||
reconstruct the composite tensor (as returned by `_composite_metadata`).
|
||||
|
||||
Returns:
|
||||
A `CompositeTensor` of type `cls`.
|
||||
"""
|
||||
return metadata._from_components(components) # pylint: disable=protected-access
|
||||
|
||||
def _shape_invariant_to_type_spec(self, shape):
|
||||
"""Returns a TypeSpec given a shape invariant (used by `tf.while_loop`).
|
||||
|
||||
@ -111,16 +71,6 @@ class CompositeTensor(object):
|
||||
raise NotImplementedError("%s._shape_invariant_to_type_spec"
|
||||
% type(self).__name__)
|
||||
|
||||
# TODO(b/133606651) Remove this property, since it's not clear what it should
|
||||
# return if a CompositeTensor has a mix of graph and non-graph components.
|
||||
# Update users to perform an appropraite check themselves.
|
||||
@property
|
||||
def _is_graph_tensor(self):
|
||||
"""Returns True if this tensor's components belong to a TF graph."""
|
||||
components = self._type_spec._to_components(self) # pylint: disable=protected-access
|
||||
tensors = nest.flatten(components, expand_composites=True)
|
||||
return any(hasattr(t, "graph") for t in tensors)
|
||||
|
||||
def _consumers(self):
|
||||
"""Returns a list of `Operation`s that consume this `CompositeTensor`.
|
||||
|
||||
@ -132,7 +82,7 @@ class CompositeTensor(object):
|
||||
"""
|
||||
consumers = nest.flatten([
|
||||
component.consumers()
|
||||
for component in self._to_components()
|
||||
for component in nest.flatten(self, expand_composites=True)
|
||||
if getattr(component, "graph", None) is not None
|
||||
])
|
||||
return list(set(consumers))
|
||||
|
@ -342,7 +342,8 @@ def is_symbolic_tensor(tensor):
|
||||
return (getattr(tensor, '_keras_history', False) or
|
||||
not context.executing_eagerly())
|
||||
if isinstance(tensor, composite_tensor.CompositeTensor):
|
||||
return tensor._is_graph_tensor # pylint: disable=protected-access
|
||||
component_tensors = nest.flatten(tensor, expand_composites=True)
|
||||
return any(hasattr(t, 'graph') for t in component_tensors)
|
||||
if isinstance(tensor, ops.Tensor):
|
||||
return hasattr(tensor, 'graph')
|
||||
return False
|
||||
|
@ -39,8 +39,8 @@ import collections as _collections
|
||||
import six as _six
|
||||
|
||||
from tensorflow.python import _pywrap_utils
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.python.util.compat import collections_abc as _collections_abc
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
_SHALLOW_TREE_HAS_INVALID_KEYS = (
|
||||
@ -197,7 +197,8 @@ def _yield_sorted_items(iterable):
|
||||
for field in iterable._fields:
|
||||
yield field, getattr(iterable, field)
|
||||
elif _is_composite_tensor(iterable):
|
||||
yield type(iterable).__name__, iterable._to_components() # pylint: disable=protected-access
|
||||
type_spec = iterable._type_spec # pylint: disable=protected-access
|
||||
yield type(iterable).__name__, type_spec._to_components(iterable) # pylint: disable=protected-access
|
||||
elif _is_type_spec(iterable):
|
||||
# Note: to allow CompositeTensors and their TypeSpecs to have matching
|
||||
# structures, we need to use the same key string here.
|
||||
|
Loading…
Reference in New Issue
Block a user