Switch all CompositeTensor instance checks in Keras to use a centralized tf_utils.is_extension_type
util. This util will use the public ExtensionType api once it is in place.
PiperOrigin-RevId: 332971935 Change-Id: Ic73743d70b2e11e431262d209ac1fd8666570309
This commit is contained in:
parent
e90dd8abe7
commit
5d5534edf7
tensorflow/python/keras
@ -41,7 +41,6 @@ from tensorflow.python.distribute import distribution_strategy_context
|
|||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import function as eager_function
|
from tensorflow.python.eager import function as eager_function
|
||||||
from tensorflow.python.eager import lift_to_graph
|
from tensorflow.python.eager import lift_to_graph
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import device_spec
|
from tensorflow.python.framework import device_spec
|
||||||
@ -1268,7 +1267,8 @@ def is_placeholder(x):
|
|||||||
try:
|
try:
|
||||||
if keras_tensor.keras_tensors_enabled():
|
if keras_tensor.keras_tensors_enabled():
|
||||||
return hasattr(x, '_is_backend_placeholder')
|
return hasattr(x, '_is_backend_placeholder')
|
||||||
if isinstance(x, composite_tensor.CompositeTensor):
|
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
|
||||||
|
if tf_utils.is_extension_type(x):
|
||||||
flat_components = nest.flatten(x, expand_composites=True)
|
flat_components = nest.flatten(x, expand_composites=True)
|
||||||
return py_any(is_placeholder(c) for c in flat_components)
|
return py_any(is_placeholder(c) for c in flat_components)
|
||||||
else:
|
else:
|
||||||
@ -3881,7 +3881,8 @@ class GraphExecutionFunction(object):
|
|||||||
# CompositeTensors. E.g., if output_structure contains a SparseTensor, then
|
# CompositeTensors. E.g., if output_structure contains a SparseTensor, then
|
||||||
# this ensures that we return its value as a SparseTensorValue rather than
|
# this ensures that we return its value as a SparseTensorValue rather than
|
||||||
# a SparseTensor.
|
# a SparseTensor.
|
||||||
if isinstance(tensor, composite_tensor.CompositeTensor):
|
from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top
|
||||||
|
if tf_utils.is_extension_type(tensor):
|
||||||
return self._session.run(tensor)
|
return self._session.run(tensor)
|
||||||
else:
|
else:
|
||||||
return tensor
|
return tensor
|
||||||
|
@ -74,6 +74,7 @@ py_library(
|
|||||||
"//tensorflow/python/keras/utils:engine_utils",
|
"//tensorflow/python/keras/utils:engine_utils",
|
||||||
"//tensorflow/python/keras/utils:metrics_utils",
|
"//tensorflow/python/keras/utils:metrics_utils",
|
||||||
"//tensorflow/python/keras/utils:mode_keys",
|
"//tensorflow/python/keras/utils:mode_keys",
|
||||||
|
"//tensorflow/python/keras/utils:tf_utils",
|
||||||
"//tensorflow/python/keras/utils:version_utils",
|
"//tensorflow/python/keras/utils:version_utils",
|
||||||
"//tensorflow/python/module",
|
"//tensorflow/python/module",
|
||||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||||
@ -178,6 +179,7 @@ py_library(
|
|||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
"//tensorflow/python/keras/utils:engine_utils",
|
"//tensorflow/python/keras/utils:engine_utils",
|
||||||
|
"//tensorflow/python/keras/utils:tf_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,10 +40,10 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import smart_cond
|
from tensorflow.python.framework import smart_cond
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework.ops import composite_tensor
|
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras.engine import training_utils
|
from tensorflow.python.keras.engine import training_utils
|
||||||
from tensorflow.python.keras.utils import data_utils
|
from tensorflow.python.keras.utils import data_utils
|
||||||
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
@ -527,7 +527,7 @@ class CompositeTensorDataAdapter(DataAdapter):
|
|||||||
|
|
||||||
def _is_composite(v):
|
def _is_composite(v):
|
||||||
# Dataset inherits from CompositeTensor but shouldn't be handled here.
|
# Dataset inherits from CompositeTensor but shouldn't be handled here.
|
||||||
if (isinstance(v, composite_tensor.CompositeTensor) and
|
if (tf_utils.is_extension_type(v) and
|
||||||
not isinstance(v, dataset_ops.DatasetV2)):
|
not isinstance(v, dataset_ops.DatasetV2)):
|
||||||
return True
|
return True
|
||||||
# Support Scipy sparse tensors if scipy is installed
|
# Support Scipy sparse tensors if scipy is installed
|
||||||
|
@ -27,7 +27,6 @@ import warnings
|
|||||||
from six.moves import zip # pylint: disable=redefined-builtin
|
from six.moves import zip # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras.engine import base_layer
|
from tensorflow.python.keras.engine import base_layer
|
||||||
@ -641,7 +640,7 @@ class Functional(training_lib.Model):
|
|||||||
|
|
||||||
# Dtype casting.
|
# Dtype casting.
|
||||||
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||||
elif isinstance(tensor, composite_tensor.CompositeTensor):
|
elif tf_utils.is_extension_type(tensor):
|
||||||
# Dtype casting.
|
# Dtype casting.
|
||||||
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
@ -183,8 +182,8 @@ class InputLayer(base_layer.Layer):
|
|||||||
node_module.Node(layer=self, outputs=input_tensor)
|
node_module.Node(layer=self, outputs=input_tensor)
|
||||||
|
|
||||||
# Store type spec
|
# Store type spec
|
||||||
if isinstance(input_tensor, (
|
if isinstance(input_tensor, keras_tensor.KerasTensor) or (
|
||||||
composite_tensor.CompositeTensor, keras_tensor.KerasTensor)):
|
tf_utils.is_extension_type(input_tensor)):
|
||||||
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
|
self._type_spec = input_tensor._type_spec # pylint: disable=protected-access
|
||||||
else:
|
else:
|
||||||
self._type_spec = tensor_spec.TensorSpec(
|
self._type_spec = tensor_spec.TensorSpec(
|
||||||
|
@ -29,7 +29,6 @@ from tensorflow.python.distribute import distribution_strategy_context
|
|||||||
from tensorflow.python.distribute import parameter_server_strategy
|
from tensorflow.python.distribute import parameter_server_strategy
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import composite_tensor_utils
|
from tensorflow.python.framework import composite_tensor_utils
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -57,6 +56,7 @@ from tensorflow.python.keras.utils import data_utils
|
|||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
from tensorflow.python.keras.utils import losses_utils
|
from tensorflow.python.keras.utils import losses_utils
|
||||||
from tensorflow.python.keras.utils import tf_inspect
|
from tensorflow.python.keras.utils import tf_inspect
|
||||||
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -2378,7 +2378,7 @@ class Model(training_lib.Model):
|
|||||||
|
|
||||||
def _type_spec_from_value(value):
|
def _type_spec_from_value(value):
|
||||||
"""Grab type_spec without converting array-likes to tensors."""
|
"""Grab type_spec without converting array-likes to tensors."""
|
||||||
if isinstance(value, composite_tensor.CompositeTensor):
|
if tf_utils.is_extension_type(value):
|
||||||
return value._type_spec # pylint: disable=protected-access
|
return value._type_spec # pylint: disable=protected-access
|
||||||
# Get a TensorSpec for array-like data without
|
# Get a TensorSpec for array-like data without
|
||||||
# converting the data to a Tensor
|
# converting the data to a Tensor
|
||||||
|
@ -284,6 +284,23 @@ def are_all_symbolic_tensors(tensors):
|
|||||||
_user_convertible_tensor_types = set()
|
_user_convertible_tensor_types = set()
|
||||||
|
|
||||||
|
|
||||||
|
def is_extension_type(tensor):
|
||||||
|
"""Returns whether a tensor is of an ExtensionType.
|
||||||
|
|
||||||
|
github.com/tensorflow/community/pull/269
|
||||||
|
Currently it works by checking if `tensor` is a `CompositeTensor` instance,
|
||||||
|
but this will be changed to use an appropriate extensiontype protocol
|
||||||
|
check once ExtensionType is made public.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
tensor: An object to test
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the tensor is an extension type object, false if not.
|
||||||
|
"""
|
||||||
|
return isinstance(tensor, composite_tensor.CompositeTensor)
|
||||||
|
|
||||||
|
|
||||||
def is_symbolic_tensor(tensor):
|
def is_symbolic_tensor(tensor):
|
||||||
"""Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.
|
"""Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.
|
||||||
|
|
||||||
@ -298,7 +315,7 @@ def is_symbolic_tensor(tensor):
|
|||||||
"""
|
"""
|
||||||
if isinstance(tensor, ops.Tensor):
|
if isinstance(tensor, ops.Tensor):
|
||||||
return hasattr(tensor, 'graph')
|
return hasattr(tensor, 'graph')
|
||||||
elif isinstance(tensor, composite_tensor.CompositeTensor):
|
elif is_extension_type(tensor):
|
||||||
component_tensors = nest.flatten(tensor, expand_composites=True)
|
component_tensors = nest.flatten(tensor, expand_composites=True)
|
||||||
return any(hasattr(t, 'graph') for t in component_tensors)
|
return any(hasattr(t, 'graph') for t in component_tensors)
|
||||||
elif isinstance(tensor, variables.Variable):
|
elif isinstance(tensor, variables.Variable):
|
||||||
@ -351,7 +368,7 @@ def register_symbolic_tensor_type(cls):
|
|||||||
|
|
||||||
def type_spec_from_value(value):
|
def type_spec_from_value(value):
|
||||||
"""Grab type_spec without converting array-likes to tensors."""
|
"""Grab type_spec without converting array-likes to tensors."""
|
||||||
if isinstance(value, composite_tensor.CompositeTensor):
|
if is_extension_type(value):
|
||||||
return value._type_spec # pylint: disable=protected-access
|
return value._type_spec # pylint: disable=protected-access
|
||||||
# Get a TensorSpec for array-like data without
|
# Get a TensorSpec for array-like data without
|
||||||
# converting the data to a Tensor
|
# converting the data to a Tensor
|
||||||
@ -441,7 +458,7 @@ def get_tensor_spec(t, dynamic_batch=False, name=None):
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if isinstance(t, type_spec.TypeSpec):
|
if isinstance(t, type_spec.TypeSpec):
|
||||||
spec = t
|
spec = t
|
||||||
elif isinstance(t, composite_tensor.CompositeTensor):
|
elif is_extension_type(t):
|
||||||
# TODO(b/148821952): Should these specs have a name attr?
|
# TODO(b/148821952): Should these specs have a name attr?
|
||||||
spec = t._type_spec
|
spec = t._type_spec
|
||||||
elif (hasattr(t, '_keras_history') and
|
elif (hasattr(t, '_keras_history') and
|
||||||
|
@ -22,11 +22,14 @@ from absl.testing import parameterized
|
|||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.keras import combinations
|
from tensorflow.python.keras import combinations
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -200,5 +203,24 @@ class TestIsRagged(test.TestCase):
|
|||||||
tensor = [1., 2., 3.]
|
tensor = [1., 2., 3.]
|
||||||
self.assertFalse(tf_utils.is_ragged(tensor))
|
self.assertFalse(tf_utils.is_ragged(tensor))
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsExtensionType(test.TestCase):
|
||||||
|
|
||||||
|
def test_is_extension_type_return_true_for_ragged_tensor(self):
|
||||||
|
self.assertTrue(tf_utils.is_extension_type(
|
||||||
|
ragged_factory_ops.constant([[1, 2], [3]])))
|
||||||
|
|
||||||
|
def test_is_extension_type_return_true_for_sparse_tensor(self):
|
||||||
|
self.assertTrue(tf_utils.is_extension_type(
|
||||||
|
sparse_ops.from_dense([[1, 2], [3, 4]])))
|
||||||
|
|
||||||
|
def test_is_extension_type_return_false_for_dense_tensor(self):
|
||||||
|
self.assertFalse(tf_utils.is_extension_type(
|
||||||
|
constant_op.constant([[1, 2], [3, 4]])))
|
||||||
|
|
||||||
|
def test_is_extension_type_return_false_for_list(self):
|
||||||
|
tensor = [1., 2., 3.]
|
||||||
|
self.assertFalse(tf_utils.is_extension_type(tensor))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user