diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 15eed32fe4b..7766a735fe6 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -41,7 +41,6 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.eager import function as eager_function from tensorflow.python.eager import lift_to_graph -from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import device_spec @@ -1268,7 +1267,8 @@ def is_placeholder(x): try: if keras_tensor.keras_tensors_enabled(): return hasattr(x, '_is_backend_placeholder') - if isinstance(x, composite_tensor.CompositeTensor): + from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top + if tf_utils.is_extension_type(x): flat_components = nest.flatten(x, expand_composites=True) return py_any(is_placeholder(c) for c in flat_components) else: @@ -3881,7 +3881,8 @@ class GraphExecutionFunction(object): # CompositeTensors. E.g., if output_structure contains a SparseTensor, then # this ensures that we return its value as a SparseTensorValue rather than # a SparseTensor. - if isinstance(tensor, composite_tensor.CompositeTensor): + from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top + if tf_utils.is_extension_type(tensor): return self._session.run(tensor) else: return tensor diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 1a2b8c48d20..258dc2f1290 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -74,6 +74,7 @@ py_library( "//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:metrics_utils", "//tensorflow/python/keras/utils:mode_keys", + "//tensorflow/python/keras/utils:tf_utils", "//tensorflow/python/keras/utils:version_utils", "//tensorflow/python/module", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -178,6 +179,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras/utils:engine_utils", + "//tensorflow/python/keras/utils:tf_utils", ], ) diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index e8759b35448..7996cd31ea5 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -40,10 +40,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework.ops import composite_tensor from tensorflow.python.keras import backend from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.utils import data_utils +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -527,7 +527,7 @@ class CompositeTensorDataAdapter(DataAdapter): def _is_composite(v): # Dataset inherits from CompositeTensor but shouldn't be handled here. - if (isinstance(v, composite_tensor.CompositeTensor) and + if (tf_utils.is_extension_type(v) and not isinstance(v, dataset_ops.DatasetV2)): return True # Support Scipy sparse tensors if scipy is installed diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index f3911dba9c4..892773fa656 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -27,7 +27,6 @@ import warnings from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context -from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer @@ -641,7 +640,7 @@ class Functional(training_lib.Model): # Dtype casting. tensor = math_ops.cast(tensor, dtype=ref_input.dtype) - elif isinstance(tensor, composite_tensor.CompositeTensor): + elif tf_utils.is_extension_type(tensor): # Dtype casting. tensor = math_ops.cast(tensor, dtype=ref_input.dtype) diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index 33f9320e516..f92709a1128 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.distribute import distribution_strategy_context -from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend @@ -183,8 +182,8 @@ class InputLayer(base_layer.Layer): node_module.Node(layer=self, outputs=input_tensor) # Store type spec - if isinstance(input_tensor, ( - composite_tensor.CompositeTensor, keras_tensor.KerasTensor)): + if isinstance(input_tensor, keras_tensor.KerasTensor) or ( + tf_utils.is_extension_type(input_tensor)): self._type_spec = input_tensor._type_spec # pylint: disable=protected-access else: self._type_spec = tensor_spec.TensorSpec( diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 61f81d1c047..77af55ae39b 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.eager import context from tensorflow.python.eager import def_function -from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import composite_tensor_utils from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -57,6 +56,7 @@ from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import losses_utils from tensorflow.python.keras.utils import tf_inspect +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -2378,7 +2378,7 @@ class Model(training_lib.Model): def _type_spec_from_value(value): """Grab type_spec without converting array-likes to tensors.""" - if isinstance(value, composite_tensor.CompositeTensor): + if tf_utils.is_extension_type(value): return value._type_spec # pylint: disable=protected-access # Get a TensorSpec for array-like data without # converting the data to a Tensor diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 3e75da4ec13..a7334bc6132 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -284,6 +284,23 @@ def are_all_symbolic_tensors(tensors): _user_convertible_tensor_types = set() +def is_extension_type(tensor): + """Returns whether a tensor is of an ExtensionType. + + github.com/tensorflow/community/pull/269 + Currently it works by checking if `tensor` is a `CompositeTensor` instance, + but this will be changed to use an appropriate extensiontype protocol + check once ExtensionType is made public. + + Arguments: + tensor: An object to test + + Returns: + True if the tensor is an extension type object, false if not. + """ + return isinstance(tensor, composite_tensor.CompositeTensor) + + def is_symbolic_tensor(tensor): """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor. @@ -298,7 +315,7 @@ def is_symbolic_tensor(tensor): """ if isinstance(tensor, ops.Tensor): return hasattr(tensor, 'graph') - elif isinstance(tensor, composite_tensor.CompositeTensor): + elif is_extension_type(tensor): component_tensors = nest.flatten(tensor, expand_composites=True) return any(hasattr(t, 'graph') for t in component_tensors) elif isinstance(tensor, variables.Variable): @@ -351,7 +368,7 @@ def register_symbolic_tensor_type(cls): def type_spec_from_value(value): """Grab type_spec without converting array-likes to tensors.""" - if isinstance(value, composite_tensor.CompositeTensor): + if is_extension_type(value): return value._type_spec # pylint: disable=protected-access # Get a TensorSpec for array-like data without # converting the data to a Tensor @@ -441,7 +458,7 @@ def get_tensor_spec(t, dynamic_batch=False, name=None): # pylint: disable=protected-access if isinstance(t, type_spec.TypeSpec): spec = t - elif isinstance(t, composite_tensor.CompositeTensor): + elif is_extension_type(t): # TODO(b/148821952): Should these specs have a name attr? spec = t._type_spec elif (hasattr(t, '_keras_history') and diff --git a/tensorflow/python/keras/utils/tf_utils_test.py b/tensorflow/python/keras/utils/tf_utils_test.py index 73d8671e388..f096c61ab3c 100644 --- a/tensorflow/python/keras/utils/tf_utils_test.py +++ b/tensorflow/python/keras/utils/tf_utils_test.py @@ -22,11 +22,14 @@ from absl.testing import parameterized from tensorflow.python import keras from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import combinations from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test @@ -200,5 +203,24 @@ class TestIsRagged(test.TestCase): tensor = [1., 2., 3.] self.assertFalse(tf_utils.is_ragged(tensor)) + +class TestIsExtensionType(test.TestCase): + + def test_is_extension_type_return_true_for_ragged_tensor(self): + self.assertTrue(tf_utils.is_extension_type( + ragged_factory_ops.constant([[1, 2], [3]]))) + + def test_is_extension_type_return_true_for_sparse_tensor(self): + self.assertTrue(tf_utils.is_extension_type( + sparse_ops.from_dense([[1, 2], [3, 4]]))) + + def test_is_extension_type_return_false_for_dense_tensor(self): + self.assertFalse(tf_utils.is_extension_type( + constant_op.constant([[1, 2], [3, 4]]))) + + def test_is_extension_type_return_false_for_list(self): + tensor = [1., 2., 3.] + self.assertFalse(tf_utils.is_extension_type(tensor)) + if __name__ == '__main__': test.main()