Switch all CompositeTensor instance checks in Keras to use a centralized tf_utils.is_extension_type util. This util will use the public ExtensionType api once it is in place.

PiperOrigin-RevId: 332971935
Change-Id: Ic73743d70b2e11e431262d209ac1fd8666570309
This commit is contained in:
Tomer Kaftan 2020-09-21 17:23:58 -07:00 committed by TensorFlower Gardener
parent e90dd8abe7
commit 5d5534edf7
8 changed files with 55 additions and 15 deletions

View File

@ -41,7 +41,6 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device_spec
@ -1268,7 +1267,8 @@ def is_placeholder(x):
try:
if keras_tensor.keras_tensors_enabled():
return hasattr(x, '_is_backend_placeholder')
if isinstance(x, composite_tensor.CompositeTensor):
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
if tf_utils.is_extension_type(x):
flat_components = nest.flatten(x, expand_composites=True)
return py_any(is_placeholder(c) for c in flat_components)
else:
@ -3881,7 +3881,8 @@ class GraphExecutionFunction(object):
# CompositeTensors. E.g., if output_structure contains a SparseTensor, then
# this ensures that we return its value as a SparseTensorValue rather than
# a SparseTensor.
if isinstance(tensor, composite_tensor.CompositeTensor):
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
if tf_utils.is_extension_type(tensor):
return self._session.run(tensor)
else:
return tensor

View File

@ -74,6 +74,7 @@ py_library(
"//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/keras/utils:metrics_utils",
"//tensorflow/python/keras/utils:mode_keys",
"//tensorflow/python/keras/utils:tf_utils",
"//tensorflow/python/keras/utils:version_utils",
"//tensorflow/python/module",
"//tensorflow/python/ops/ragged:ragged_tensor",
@ -178,6 +179,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/keras/utils:tf_utils",
],
)

View File

@ -40,10 +40,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework.ops import composite_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@ -527,7 +527,7 @@ class CompositeTensorDataAdapter(DataAdapter):
def _is_composite(v):
# Dataset inherits from CompositeTensor but shouldn't be handled here.
if (isinstance(v, composite_tensor.CompositeTensor) and
if (tf_utils.is_extension_type(v) and
not isinstance(v, dataset_ops.DatasetV2)):
return True
# Support Scipy sparse tensors if scipy is installed

View File

@ -27,7 +27,6 @@ import warnings
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
@ -641,7 +640,7 @@ class Functional(training_lib.Model):
# Dtype casting.
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
elif isinstance(tensor, composite_tensor.CompositeTensor):
elif tf_utils.is_extension_type(tensor):
# Dtype casting.
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)

View File

@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
@ -183,8 +182,8 @@ class InputLayer(base_layer.Layer):
node_module.Node(layer=self, outputs=input_tensor)
# Store type spec
if isinstance(input_tensor, (
composite_tensor.CompositeTensor, keras_tensor.KerasTensor)):
if isinstance(input_tensor, keras_tensor.KerasTensor) or (
tf_utils.is_extension_type(input_tensor)):
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
else:
self._type_spec = tensor_spec.TensorSpec(

View File

@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import composite_tensor_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@ -57,6 +56,7 @@ from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -2378,7 +2378,7 @@ class Model(training_lib.Model):
def _type_spec_from_value(value):
"""Grab type_spec without converting array-likes to tensors."""
if isinstance(value, composite_tensor.CompositeTensor):
if tf_utils.is_extension_type(value):
return value._type_spec # pylint: disable=protected-access
# Get a TensorSpec for array-like data without
# converting the data to a Tensor

View File

@ -284,6 +284,23 @@ def are_all_symbolic_tensors(tensors):
_user_convertible_tensor_types = set()
def is_extension_type(tensor):
"""Returns whether a tensor is of an ExtensionType.
github.com/tensorflow/community/pull/269
Currently it works by checking if `tensor` is a `CompositeTensor` instance,
but this will be changed to use an appropriate extensiontype protocol
check once ExtensionType is made public.
Arguments:
tensor: An object to test
Returns:
True if the tensor is an extension type object, false if not.
"""
return isinstance(tensor, composite_tensor.CompositeTensor)
def is_symbolic_tensor(tensor):
"""Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.
@ -298,7 +315,7 @@ def is_symbolic_tensor(tensor):
"""
if isinstance(tensor, ops.Tensor):
return hasattr(tensor, 'graph')
elif isinstance(tensor, composite_tensor.CompositeTensor):
elif is_extension_type(tensor):
component_tensors = nest.flatten(tensor, expand_composites=True)
return any(hasattr(t, 'graph') for t in component_tensors)
elif isinstance(tensor, variables.Variable):
@ -351,7 +368,7 @@ def register_symbolic_tensor_type(cls):
def type_spec_from_value(value):
"""Grab type_spec without converting array-likes to tensors."""
if isinstance(value, composite_tensor.CompositeTensor):
if is_extension_type(value):
return value._type_spec # pylint: disable=protected-access
# Get a TensorSpec for array-like data without
# converting the data to a Tensor
@ -441,7 +458,7 @@ def get_tensor_spec(t, dynamic_batch=False, name=None):
# pylint: disable=protected-access
if isinstance(t, type_spec.TypeSpec):
spec = t
elif isinstance(t, composite_tensor.CompositeTensor):
elif is_extension_type(t):
# TODO(b/148821952): Should these specs have a name attr?
spec = t._type_spec
elif (hasattr(t, '_keras_history') and

View File

@ -22,11 +22,14 @@ from absl.testing import parameterized
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import combinations
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
@ -200,5 +203,24 @@ class TestIsRagged(test.TestCase):
tensor = [1., 2., 3.]
self.assertFalse(tf_utils.is_ragged(tensor))
class TestIsExtensionType(test.TestCase):
def test_is_extension_type_return_true_for_ragged_tensor(self):
self.assertTrue(tf_utils.is_extension_type(
ragged_factory_ops.constant([[1, 2], [3]])))
def test_is_extension_type_return_true_for_sparse_tensor(self):
self.assertTrue(tf_utils.is_extension_type(
sparse_ops.from_dense([[1, 2], [3, 4]])))
def test_is_extension_type_return_false_for_dense_tensor(self):
self.assertFalse(tf_utils.is_extension_type(
constant_op.constant([[1, 2], [3, 4]])))
def test_is_extension_type_return_false_for_list(self):
tensor = [1., 2., 3.]
self.assertFalse(tf_utils.is_extension_type(tensor))
if __name__ == '__main__':
test.main()