From 84f9d53683484cefd6c87bb4655f3d658a02171a Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Thu, 3 Oct 2019 14:44:09 -0700 Subject: [PATCH] Remove deprecated APIs from CompositeTensor PiperOrigin-RevId: 272745538 --- tensorflow/python/distribute/values_test.py | 9 ++-- .../python/framework/composite_tensor.py | 54 +------------------ tensorflow/python/keras/utils/tf_utils.py | 3 +- tensorflow/python/util/nest.py | 5 +- 4 files changed, 11 insertions(+), 60 deletions(-) diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index cb6cdf975ed..6c24063101d 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -1185,8 +1185,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase): def testIsGraphTensor(self): per_replica = values.PerReplica(values.SingleDeviceMap("CPU"), (constant_op.constant(1.),)) - self.assertEqual(per_replica._is_graph_tensor, - not context.executing_eagerly()) + for t in nest.flatten(per_replica, expand_composites=True): + self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly()) def testDoesNotTriggerFunctionTracing(self): traces = [] @@ -1223,9 +1223,8 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase): values.SingleDeviceMap("CPU"), (constant_op.constant(1.),)) y = f(x) self.assertIsNot(x, y) - for a, b in zip(x._to_components(), y._to_components()): - self.assertAllEqual(a, b) - self.assertEqual(x._component_metadata(), y._component_metadata()) + nest.map_structure(self.assertAllEqual, x, y, expand_composites=True) + self.assertEqual(x._type_spec, y._type_spec) @test_util.run_in_graph_and_eager_modes def testCondWithTensorValues(self): diff --git a/tensorflow/python/framework/composite_tensor.py b/tensorflow/python/framework/composite_tensor.py index 512fff92558..b7a4d65b412 100644 --- a/tensorflow/python/framework/composite_tensor.py +++ b/tensorflow/python/framework/composite_tensor.py @@ -22,8 +22,8 @@ import abc import six -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python import _pywrap_utils +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.util import nest @@ -53,46 +53,6 @@ class CompositeTensor(object): """A `TypeSpec` describing the type of this value.""" raise NotImplementedError("%s._type_spec()" % type(self).__name__) - # Deprecated -- use self._type_spec._to_components(self) instead. - # TODO(b/133606651) Remove all callers and then delete this method. - def _to_components(self): - """Decomposes this composite tensor into its component tensors. - - Returns: - A nested structure of `tf.Tensor`s and `CompositeTensor`s that can be - used to reconstruct this composite tensor (along with metadata returned - by `_component_metadata`). - """ - return self._type_spec._to_components(self) # pylint: disable=protected-access - - # Deprecated -- use self._type_spec instead. - # TODO(b/133606651) Remove all callers and then delete this method. - def _component_metadata(self): - """Returns any non-tensor metadata needed to reconstruct a composite tensor. - - Returns: - A nested structure of metadata that can be used to reconstruct this - composite tensor (along with the tensors returned by `_to_components`). - """ - return self._type_spec - - # Deprecated -- use metadata._from_components(components) instead. - # TODO(b/133606651) Remove all callers and then delete this method. - @staticmethod - def _from_components(components, metadata): - """Creates a composite tensor of type `cls` from components. - - Args: - components: A nested structure whose values are `tf.Tensor`s or - `tf.CompositeTensor`s (as returned by `_to_components`). - metadata: A nested structure containing any additional metadata needed to - reconstruct the composite tensor (as returned by `_composite_metadata`). - - Returns: - A `CompositeTensor` of type `cls`. - """ - return metadata._from_components(components) # pylint: disable=protected-access - def _shape_invariant_to_type_spec(self, shape): """Returns a TypeSpec given a shape invariant (used by `tf.while_loop`). @@ -111,16 +71,6 @@ class CompositeTensor(object): raise NotImplementedError("%s._shape_invariant_to_type_spec" % type(self).__name__) - # TODO(b/133606651) Remove this property, since it's not clear what it should - # return if a CompositeTensor has a mix of graph and non-graph components. - # Update users to perform an appropraite check themselves. - @property - def _is_graph_tensor(self): - """Returns True if this tensor's components belong to a TF graph.""" - components = self._type_spec._to_components(self) # pylint: disable=protected-access - tensors = nest.flatten(components, expand_composites=True) - return any(hasattr(t, "graph") for t in tensors) - def _consumers(self): """Returns a list of `Operation`s that consume this `CompositeTensor`. @@ -132,7 +82,7 @@ class CompositeTensor(object): """ consumers = nest.flatten([ component.consumers() - for component in self._to_components() + for component in nest.flatten(self, expand_composites=True) if getattr(component, "graph", None) is not None ]) return list(set(consumers)) diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 24da4add22c..cec7497851f 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -342,7 +342,8 @@ def is_symbolic_tensor(tensor): return (getattr(tensor, '_keras_history', False) or not context.executing_eagerly()) if isinstance(tensor, composite_tensor.CompositeTensor): - return tensor._is_graph_tensor # pylint: disable=protected-access + component_tensors = nest.flatten(tensor, expand_composites=True) + return any(hasattr(t, 'graph') for t in component_tensors) if isinstance(tensor, ops.Tensor): return hasattr(tensor, 'graph') return False diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 2284b6cc9c8..a4466537edf 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -39,8 +39,8 @@ import collections as _collections import six as _six from tensorflow.python import _pywrap_utils -from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.compat import collections_abc as _collections_abc +from tensorflow.python.util.tf_export import tf_export _SHALLOW_TREE_HAS_INVALID_KEYS = ( @@ -197,7 +197,8 @@ def _yield_sorted_items(iterable): for field in iterable._fields: yield field, getattr(iterable, field) elif _is_composite_tensor(iterable): - yield type(iterable).__name__, iterable._to_components() # pylint: disable=protected-access + type_spec = iterable._type_spec # pylint: disable=protected-access + yield type(iterable).__name__, type_spec._to_components(iterable) # pylint: disable=protected-access elif _is_type_spec(iterable): # Note: to allow CompositeTensors and their TypeSpecs to have matching # structures, we need to use the same key string here.