Remove deprecated APIs from CompositeTensor

PiperOrigin-RevId: 272745538
This commit is contained in:
Gaurav Jain 2019-10-03 14:44:09 -07:00 committed by TensorFlower Gardener
parent c73f9f85fc
commit 84f9d53683
4 changed files with 11 additions and 60 deletions
tensorflow/python
distribute
framework
keras/utils
util

View File

@ -1185,8 +1185,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
def testIsGraphTensor(self):
per_replica = values.PerReplica(values.SingleDeviceMap("CPU"),
(constant_op.constant(1.),))
self.assertEqual(per_replica._is_graph_tensor,
not context.executing_eagerly())
for t in nest.flatten(per_replica, expand_composites=True):
self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())
def testDoesNotTriggerFunctionTracing(self):
traces = []
@ -1223,9 +1223,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
values.SingleDeviceMap("CPU"), (constant_op.constant(1.),))
y = f(x)
self.assertIsNot(x, y)
for a, b in zip(x._to_components(), y._to_components()):
self.assertAllEqual(a, b)
self.assertEqual(x._component_metadata(), y._component_metadata())
nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
self.assertEqual(x._type_spec, y._type_spec)
@test_util.run_in_graph_and_eager_modes
def testCondWithTensorValues(self):

View File

@ -22,8 +22,8 @@ import abc
import six
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python import _pywrap_utils
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.util import nest
@ -53,46 +53,6 @@ class CompositeTensor(object):
"""A `TypeSpec` describing the type of this value."""
raise NotImplementedError("%s._type_spec()" % type(self).__name__)
# Deprecated -- use self._type_spec._to_components(self) instead.
# TODO(b/133606651) Remove all callers and then delete this method.
def _to_components(self):
"""Decomposes this composite tensor into its component tensors.
Returns:
A nested structure of `tf.Tensor`s and `CompositeTensor`s that can be
used to reconstruct this composite tensor (along with metadata returned
by `_component_metadata`).
"""
return self._type_spec._to_components(self) # pylint: disable=protected-access
# Deprecated -- use self._type_spec instead.
# TODO(b/133606651) Remove all callers and then delete this method.
def _component_metadata(self):
"""Returns any non-tensor metadata needed to reconstruct a composite tensor.
Returns:
A nested structure of metadata that can be used to reconstruct this
composite tensor (along with the tensors returned by `_to_components`).
"""
return self._type_spec
# Deprecated -- use metadata._from_components(components) instead.
# TODO(b/133606651) Remove all callers and then delete this method.
@staticmethod
def _from_components(components, metadata):
"""Creates a composite tensor of type `cls` from components.
Args:
components: A nested structure whose values are `tf.Tensor`s or
`tf.CompositeTensor`s (as returned by `_to_components`).
metadata: A nested structure containing any additional metadata needed to
reconstruct the composite tensor (as returned by `_composite_metadata`).
Returns:
A `CompositeTensor` of type `cls`.
"""
return metadata._from_components(components) # pylint: disable=protected-access
def _shape_invariant_to_type_spec(self, shape):
"""Returns a TypeSpec given a shape invariant (used by `tf.while_loop`).
@ -111,16 +71,6 @@ class CompositeTensor(object):
raise NotImplementedError("%s._shape_invariant_to_type_spec"
% type(self).__name__)
# TODO(b/133606651) Remove this property, since it's not clear what it should
# return if a CompositeTensor has a mix of graph and non-graph components.
# Update users to perform an appropraite check themselves.
@property
def _is_graph_tensor(self):
"""Returns True if this tensor's components belong to a TF graph."""
components = self._type_spec._to_components(self) # pylint: disable=protected-access
tensors = nest.flatten(components, expand_composites=True)
return any(hasattr(t, "graph") for t in tensors)
def _consumers(self):
"""Returns a list of `Operation`s that consume this `CompositeTensor`.
@ -132,7 +82,7 @@ class CompositeTensor(object):
"""
consumers = nest.flatten([
component.consumers()
for component in self._to_components()
for component in nest.flatten(self, expand_composites=True)
if getattr(component, "graph", None) is not None
])
return list(set(consumers))

View File

@ -342,7 +342,8 @@ def is_symbolic_tensor(tensor):
return (getattr(tensor, '_keras_history', False) or
not context.executing_eagerly())
if isinstance(tensor, composite_tensor.CompositeTensor):
return tensor._is_graph_tensor # pylint: disable=protected-access
component_tensors = nest.flatten(tensor, expand_composites=True)
return any(hasattr(t, 'graph') for t in component_tensors)
if isinstance(tensor, ops.Tensor):
return hasattr(tensor, 'graph')
return False

View File

@ -39,8 +39,8 @@ import collections as _collections
import six as _six
from tensorflow.python import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.compat import collections_abc as _collections_abc
from tensorflow.python.util.tf_export import tf_export
_SHALLOW_TREE_HAS_INVALID_KEYS = (
@ -197,7 +197,8 @@ def _yield_sorted_items(iterable):
for field in iterable._fields:
yield field, getattr(iterable, field)
elif _is_composite_tensor(iterable):
yield type(iterable).__name__, iterable._to_components() # pylint: disable=protected-access
type_spec = iterable._type_spec # pylint: disable=protected-access
yield type(iterable).__name__, type_spec._to_components(iterable) # pylint: disable=protected-access
elif _is_type_spec(iterable):
# Note: to allow CompositeTensors and their TypeSpecs to have matching
# structures, we need to use the same key string here.