Changing the behavior of tensor_util.is_tensor(x)
not to return True if x
is a subclass of CompositeTensro.
PiperOrigin-RevId: 246893379
This commit is contained in:
parent
a83582dc61
commit
4e1b2f4d8f
@ -3116,7 +3116,7 @@ class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
|
|||||||
"""Tests that a user can register a CompositeTensor converter."""
|
"""Tests that a user can register a CompositeTensor converter."""
|
||||||
x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
|
x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
|
||||||
y = ops.convert_to_tensor_or_composite(x)
|
y = ops.convert_to_tensor_or_composite(x)
|
||||||
self.assertTrue(tensor_util.is_tensor(y))
|
self.assertFalse(tensor_util.is_tensor(y))
|
||||||
self.assertIsInstance(y, _TupleTensor)
|
self.assertIsInstance(y, _TupleTensor)
|
||||||
self.assertLen(y, len(x))
|
self.assertLen(y, len(x))
|
||||||
for x_, y_ in zip(x, y):
|
for x_, y_ in zip(x, y):
|
||||||
|
@ -175,7 +175,7 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
|||||||
@property
|
@property
|
||||||
def op(self):
|
def op(self):
|
||||||
"""The `Operation` that produces `values` as an output."""
|
"""The `Operation` that produces `values` as an output."""
|
||||||
return self.values.op
|
return self._values.op
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
|
@ -22,7 +22,6 @@ import six
|
|||||||
|
|
||||||
from tensorflow.core.framework import tensor_pb2
|
from tensorflow.core.framework import tensor_pb2
|
||||||
from tensorflow.core.framework import tensor_shape_pb2
|
from tensorflow.core.framework import tensor_shape_pb2
|
||||||
from tensorflow.python.framework import composite_tensor
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -939,20 +938,16 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
|
|||||||
|
|
||||||
@tf_export("is_tensor")
|
@tf_export("is_tensor")
|
||||||
def is_tensor(x): # pylint: disable=invalid-name
|
def is_tensor(x): # pylint: disable=invalid-name
|
||||||
"""Check whether `x` is of tensor type.
|
"""Checks whether `x` is a tensor or "tensor-like".
|
||||||
|
|
||||||
Check whether an object is a tensor or a composite tensor. This check is
|
If `is_tensor(x)` returns `True`, it is safe to assume that `x` is a tensor or
|
||||||
equivalent to calling
|
can be converted to a tensor using `ops.convert_to_tensor(x)`.
|
||||||
`isinstance(x, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor, tf.Variable))`
|
|
||||||
and also checks if all the component variables of a MirroredVariable or a
|
|
||||||
SyncOnReadVariable are tensors.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: A python object to check.
|
x: A python object to check.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`True` if `x` is a tensor, `False` if not.
|
`True` if `x` is a tensor or "tensor-like", `False` if not.
|
||||||
"""
|
"""
|
||||||
return (isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) or # pylint: disable=protected-access
|
return (isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) or # pylint: disable=protected-access
|
||||||
isinstance(x, composite_tensor.CompositeTensor) or
|
getattr(x, "is_tensor_like", False))
|
||||||
(hasattr(x, "is_tensor_like") and x.is_tensor_like))
|
|
||||||
|
@ -479,7 +479,7 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
|
|||||||
|
|
||||||
|
|
||||||
def _make_output_composite_tensors_match(true_graph, false_graph):
|
def _make_output_composite_tensors_match(true_graph, false_graph):
|
||||||
"""Rewrites {true,false}_graph's outputs to use the same _TensorLike classes.
|
"""Modifies true_graph and false_graph so they have the same output signature.
|
||||||
|
|
||||||
Currently the only transformation implemented is turning a Tensor into an
|
Currently the only transformation implemented is turning a Tensor into an
|
||||||
equivalent IndexedSlices if the other branch returns an IndexedSlices.
|
equivalent IndexedSlices if the other branch returns an IndexedSlices.
|
||||||
|
@ -40,7 +40,6 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
_eval_using_default_session = ops._eval_using_default_session
|
_eval_using_default_session = ops._eval_using_default_session
|
||||||
|
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
|
Loading…
Reference in New Issue
Block a user