From 5d5534edf7d0b73cb23f7069135d674e1d27250b Mon Sep 17 00:00:00 2001
From: Tomer Kaftan <kaftan@google.com>
Date: Mon, 21 Sep 2020 17:23:58 -0700
Subject: [PATCH] Switch all CompositeTensor instance checks in Keras to use a
 centralized `tf_utils.is_extension_type` util. This util will use the public
 ExtensionType api once it is in place.

PiperOrigin-RevId: 332971935
Change-Id: Ic73743d70b2e11e431262d209ac1fd8666570309
---
 tensorflow/python/keras/backend.py            |  7 +++---
 tensorflow/python/keras/engine/BUILD          |  2 ++
 .../python/keras/engine/data_adapter.py       |  4 ++--
 tensorflow/python/keras/engine/functional.py  |  3 +--
 tensorflow/python/keras/engine/input_layer.py |  5 ++--
 tensorflow/python/keras/engine/training_v1.py |  4 ++--
 tensorflow/python/keras/utils/tf_utils.py     | 23 ++++++++++++++++---
 .../python/keras/utils/tf_utils_test.py       | 22 ++++++++++++++++++
 8 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 15eed32fe4b..7766a735fe6 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -41,7 +41,6 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as eager_function
 from tensorflow.python.eager import lift_to_graph
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device_spec
@@ -1268,7 +1267,8 @@ def is_placeholder(x):
   try:
     if keras_tensor.keras_tensors_enabled():
       return hasattr(x, '_is_backend_placeholder')
-    if isinstance(x, composite_tensor.CompositeTensor):
+    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
+    if tf_utils.is_extension_type(x):
       flat_components = nest.flatten(x, expand_composites=True)
       return py_any(is_placeholder(c) for c in flat_components)
     else:
@@ -3881,7 +3881,8 @@ class GraphExecutionFunction(object):
     # CompositeTensors. E.g., if output_structure contains a SparseTensor, then
     # this ensures that we return its value as a SparseTensorValue rather than
     # a SparseTensor.
-    if isinstance(tensor, composite_tensor.CompositeTensor):
+    from tensorflow.python.keras.utils import tf_utils  # pylint: disable=g-import-not-at-top
+    if tf_utils.is_extension_type(tensor):
       return self._session.run(tensor)
     else:
       return tensor
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index 1a2b8c48d20..258dc2f1290 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -74,6 +74,7 @@ py_library(
         "//tensorflow/python/keras/utils:engine_utils",
         "//tensorflow/python/keras/utils:metrics_utils",
         "//tensorflow/python/keras/utils:mode_keys",
+        "//tensorflow/python/keras/utils:tf_utils",
         "//tensorflow/python/keras/utils:version_utils",
         "//tensorflow/python/module",
         "//tensorflow/python/ops/ragged:ragged_tensor",
@@ -178,6 +179,7 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/keras/utils:engine_utils",
+        "//tensorflow/python/keras/utils:tf_utils",
     ],
 )
 
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index e8759b35448..7996cd31ea5 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -40,10 +40,10 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework.ops import composite_tensor
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils import data_utils
+from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -527,7 +527,7 @@ class CompositeTensorDataAdapter(DataAdapter):
 
     def _is_composite(v):
       # Dataset inherits from CompositeTensor but shouldn't be handled here.
-      if (isinstance(v, composite_tensor.CompositeTensor) and
+      if (tf_utils.is_extension_type(v) and
           not isinstance(v, dataset_ops.DatasetV2)):
         return True
       # Support Scipy sparse tensors if scipy is installed
diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index f3911dba9c4..892773fa656 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -27,7 +27,6 @@ import warnings
 from six.moves import zip  # pylint: disable=redefined-builtin
 
 from tensorflow.python.eager import context
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer
@@ -641,7 +640,7 @@ class Functional(training_lib.Model):
 
       # Dtype casting.
       tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
-    elif isinstance(tensor, composite_tensor.CompositeTensor):
+    elif tf_utils.is_extension_type(tensor):
       # Dtype casting.
       tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
 
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 33f9320e516..f92709a1128 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend
@@ -183,8 +182,8 @@ class InputLayer(base_layer.Layer):
     node_module.Node(layer=self, outputs=input_tensor)
 
     # Store type spec
-    if isinstance(input_tensor, (
-        composite_tensor.CompositeTensor, keras_tensor.KerasTensor)):
+    if isinstance(input_tensor, keras_tensor.KerasTensor) or (
+        tf_utils.is_extension_type(input_tensor)):
       self._type_spec = input_tensor._type_spec  # pylint: disable=protected-access
     else:
       self._type_spec = tensor_spec.TensorSpec(
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 61f81d1c047..77af55ae39b 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import composite_tensor_utils
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
@@ -57,6 +56,7 @@ from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import layer_utils
 from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.keras.utils import tf_inspect
+from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -2378,7 +2378,7 @@ class Model(training_lib.Model):
 
       def _type_spec_from_value(value):
         """Grab type_spec without converting array-likes to tensors."""
-        if isinstance(value, composite_tensor.CompositeTensor):
+        if tf_utils.is_extension_type(value):
           return value._type_spec  # pylint: disable=protected-access
         # Get a TensorSpec for array-like data without
         # converting the data to a Tensor
diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py
index 3e75da4ec13..a7334bc6132 100644
--- a/tensorflow/python/keras/utils/tf_utils.py
+++ b/tensorflow/python/keras/utils/tf_utils.py
@@ -284,6 +284,23 @@ def are_all_symbolic_tensors(tensors):
 _user_convertible_tensor_types = set()
 
 
+def is_extension_type(tensor):
+  """Returns whether a tensor is of an ExtensionType.
+
+  github.com/tensorflow/community/pull/269
+  Currently it works by checking if `tensor` is a `CompositeTensor` instance,
+  but this will be changed to use an appropriate extensiontype protocol
+  check once ExtensionType is made public.
+
+  Arguments:
+    tensor: An object to test
+
+  Returns:
+    True if the tensor is an extension type object, false if not.
+  """
+  return isinstance(tensor, composite_tensor.CompositeTensor)
+
+
 def is_symbolic_tensor(tensor):
   """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.
 
@@ -298,7 +315,7 @@ def is_symbolic_tensor(tensor):
   """
   if isinstance(tensor, ops.Tensor):
     return hasattr(tensor, 'graph')
-  elif isinstance(tensor, composite_tensor.CompositeTensor):
+  elif is_extension_type(tensor):
     component_tensors = nest.flatten(tensor, expand_composites=True)
     return any(hasattr(t, 'graph') for t in component_tensors)
   elif isinstance(tensor, variables.Variable):
@@ -351,7 +368,7 @@ def register_symbolic_tensor_type(cls):
 
 def type_spec_from_value(value):
   """Grab type_spec without converting array-likes to tensors."""
-  if isinstance(value, composite_tensor.CompositeTensor):
+  if is_extension_type(value):
     return value._type_spec  # pylint: disable=protected-access
   # Get a TensorSpec for array-like data without
   # converting the data to a Tensor
@@ -441,7 +458,7 @@ def get_tensor_spec(t, dynamic_batch=False, name=None):
   # pylint: disable=protected-access
   if isinstance(t, type_spec.TypeSpec):
     spec = t
-  elif isinstance(t, composite_tensor.CompositeTensor):
+  elif is_extension_type(t):
     # TODO(b/148821952): Should these specs have a name attr?
     spec = t._type_spec
   elif (hasattr(t, '_keras_history') and
diff --git a/tensorflow/python/keras/utils/tf_utils_test.py b/tensorflow/python/keras/utils/tf_utils_test.py
index 73d8671e388..f096c61ab3c 100644
--- a/tensorflow/python/keras/utils/tf_utils_test.py
+++ b/tensorflow/python/keras/utils/tf_utils_test.py
@@ -22,11 +22,14 @@ from absl.testing import parameterized
 
 from tensorflow.python import keras
 from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
 
@@ -200,5 +203,24 @@ class TestIsRagged(test.TestCase):
     tensor = [1., 2., 3.]
     self.assertFalse(tf_utils.is_ragged(tensor))
 
+
+class TestIsExtensionType(test.TestCase):
+
+  def test_is_extension_type_return_true_for_ragged_tensor(self):
+    self.assertTrue(tf_utils.is_extension_type(
+        ragged_factory_ops.constant([[1, 2], [3]])))
+
+  def test_is_extension_type_return_true_for_sparse_tensor(self):
+    self.assertTrue(tf_utils.is_extension_type(
+        sparse_ops.from_dense([[1, 2], [3, 4]])))
+
+  def test_is_extension_type_return_false_for_dense_tensor(self):
+    self.assertFalse(tf_utils.is_extension_type(
+        constant_op.constant([[1, 2], [3, 4]])))
+
+  def test_is_extension_type_return_false_for_list(self):
+    tensor = [1., 2., 3.]
+    self.assertFalse(tf_utils.is_extension_type(tensor))
+
 if __name__ == '__main__':
   test.main()