Switch all CompositeTensor instance checks in Keras to use a centralized tf_utils.is_extension_type
util. This util will use the public ExtensionType api once it is in place.
PiperOrigin-RevId: 332971935 Change-Id: Ic73743d70b2e11e431262d209ac1fd8666570309
This commit is contained in:
parent
e90dd8abe7
commit
5d5534edf7
tensorflow/python/keras
@ -41,7 +41,6 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as eager_function
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device_spec
|
||||
@ -1268,7 +1267,8 @@ def is_placeholder(x):
|
||||
try:
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
return hasattr(x, '_is_backend_placeholder')
|
||||
if isinstance(x, composite_tensor.CompositeTensor):
|
||||
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
|
||||
if tf_utils.is_extension_type(x):
|
||||
flat_components = nest.flatten(x, expand_composites=True)
|
||||
return py_any(is_placeholder(c) for c in flat_components)
|
||||
else:
|
||||
@ -3881,7 +3881,8 @@ class GraphExecutionFunction(object):
|
||||
# CompositeTensors. E.g., if output_structure contains a SparseTensor, then
|
||||
# this ensures that we return its value as a SparseTensorValue rather than
|
||||
# a SparseTensor.
|
||||
if isinstance(tensor, composite_tensor.CompositeTensor):
|
||||
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
|
||||
if tf_utils.is_extension_type(tensor):
|
||||
return self._session.run(tensor)
|
||||
else:
|
||||
return tensor
|
||||
|
@ -74,6 +74,7 @@ py_library(
|
||||
"//tensorflow/python/keras/utils:engine_utils",
|
||||
"//tensorflow/python/keras/utils:metrics_utils",
|
||||
"//tensorflow/python/keras/utils:mode_keys",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
"//tensorflow/python/keras/utils:version_utils",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
@ -178,6 +179,7 @@ py_library(
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/keras/utils:engine_utils",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -40,10 +40,10 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import smart_cond
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework.ops import composite_tensor
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
@ -527,7 +527,7 @@ class CompositeTensorDataAdapter(DataAdapter):
|
||||
|
||||
def _is_composite(v):
|
||||
# Dataset inherits from CompositeTensor but shouldn't be handled here.
|
||||
if (isinstance(v, composite_tensor.CompositeTensor) and
|
||||
if (tf_utils.is_extension_type(v) and
|
||||
not isinstance(v, dataset_ops.DatasetV2)):
|
||||
return True
|
||||
# Support Scipy sparse tensors if scipy is installed
|
||||
|
@ -27,7 +27,6 @@ import warnings
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
@ -641,7 +640,7 @@ class Functional(training_lib.Model):
|
||||
|
||||
# Dtype casting.
|
||||
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||
elif isinstance(tensor, composite_tensor.CompositeTensor):
|
||||
elif tf_utils.is_extension_type(tensor):
|
||||
# Dtype casting.
|
||||
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||
|
||||
|
@ -20,7 +20,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend
|
||||
@ -183,8 +182,8 @@ class InputLayer(base_layer.Layer):
|
||||
node_module.Node(layer=self, outputs=input_tensor)
|
||||
|
||||
# Store type spec
|
||||
if isinstance(input_tensor, (
|
||||
composite_tensor.CompositeTensor, keras_tensor.KerasTensor)):
|
||||
if isinstance(input_tensor, keras_tensor.KerasTensor) or (
|
||||
tf_utils.is_extension_type(input_tensor)):
|
||||
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
|
||||
else:
|
||||
self._type_spec = tensor_spec.TensorSpec(
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import parameter_server_strategy
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import composite_tensor_utils
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
@ -57,6 +56,7 @@ from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import losses_utils
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -2378,7 +2378,7 @@ class Model(training_lib.Model):
|
||||
|
||||
def _type_spec_from_value(value):
|
||||
"""Grab type_spec without converting array-likes to tensors."""
|
||||
if isinstance(value, composite_tensor.CompositeTensor):
|
||||
if tf_utils.is_extension_type(value):
|
||||
return value._type_spec # pylint: disable=protected-access
|
||||
# Get a TensorSpec for array-like data without
|
||||
# converting the data to a Tensor
|
||||
|
@ -284,6 +284,23 @@ def are_all_symbolic_tensors(tensors):
|
||||
_user_convertible_tensor_types = set()
|
||||
|
||||
|
||||
def is_extension_type(tensor):
|
||||
"""Returns whether a tensor is of an ExtensionType.
|
||||
|
||||
github.com/tensorflow/community/pull/269
|
||||
Currently it works by checking if `tensor` is a `CompositeTensor` instance,
|
||||
but this will be changed to use an appropriate extensiontype protocol
|
||||
check once ExtensionType is made public.
|
||||
|
||||
Arguments:
|
||||
tensor: An object to test
|
||||
|
||||
Returns:
|
||||
True if the tensor is an extension type object, false if not.
|
||||
"""
|
||||
return isinstance(tensor, composite_tensor.CompositeTensor)
|
||||
|
||||
|
||||
def is_symbolic_tensor(tensor):
|
||||
"""Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.
|
||||
|
||||
@ -298,7 +315,7 @@ def is_symbolic_tensor(tensor):
|
||||
"""
|
||||
if isinstance(tensor, ops.Tensor):
|
||||
return hasattr(tensor, 'graph')
|
||||
elif isinstance(tensor, composite_tensor.CompositeTensor):
|
||||
elif is_extension_type(tensor):
|
||||
component_tensors = nest.flatten(tensor, expand_composites=True)
|
||||
return any(hasattr(t, 'graph') for t in component_tensors)
|
||||
elif isinstance(tensor, variables.Variable):
|
||||
@ -351,7 +368,7 @@ def register_symbolic_tensor_type(cls):
|
||||
|
||||
def type_spec_from_value(value):
|
||||
"""Grab type_spec without converting array-likes to tensors."""
|
||||
if isinstance(value, composite_tensor.CompositeTensor):
|
||||
if is_extension_type(value):
|
||||
return value._type_spec # pylint: disable=protected-access
|
||||
# Get a TensorSpec for array-like data without
|
||||
# converting the data to a Tensor
|
||||
@ -441,7 +458,7 @@ def get_tensor_spec(t, dynamic_batch=False, name=None):
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(t, type_spec.TypeSpec):
|
||||
spec = t
|
||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
||||
elif is_extension_type(t):
|
||||
# TODO(b/148821952): Should these specs have a name attr?
|
||||
spec = t._type_spec
|
||||
elif (hasattr(t, '_keras_history') and
|
||||
|
@ -22,11 +22,14 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -200,5 +203,24 @@ class TestIsRagged(test.TestCase):
|
||||
tensor = [1., 2., 3.]
|
||||
self.assertFalse(tf_utils.is_ragged(tensor))
|
||||
|
||||
|
||||
class TestIsExtensionType(test.TestCase):
|
||||
|
||||
def test_is_extension_type_return_true_for_ragged_tensor(self):
|
||||
self.assertTrue(tf_utils.is_extension_type(
|
||||
ragged_factory_ops.constant([[1, 2], [3]])))
|
||||
|
||||
def test_is_extension_type_return_true_for_sparse_tensor(self):
|
||||
self.assertTrue(tf_utils.is_extension_type(
|
||||
sparse_ops.from_dense([[1, 2], [3, 4]])))
|
||||
|
||||
def test_is_extension_type_return_false_for_dense_tensor(self):
|
||||
self.assertFalse(tf_utils.is_extension_type(
|
||||
constant_op.constant([[1, 2], [3, 4]])))
|
||||
|
||||
def test_is_extension_type_return_false_for_list(self):
|
||||
tensor = [1., 2., 3.]
|
||||
self.assertFalse(tf_utils.is_extension_type(tensor))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user