Clarifies meaning of is_tensor in the documentation.

Also clarifies documentation of TensorLike. Makes RaggedTensor subclass TensorLike.
Also renames _TensorLike to TensorLike since we expect more usages of it.

PiperOrigin-RevId: 303357719
Change-Id: Ieea5b0941e2149e2ad4670e449c72c471685cf37
This commit is contained in:
A. Unique TensorFlower 2020-03-27 10:29:03 -07:00 committed by TensorFlower Gardener
parent 6766a8ee3e
commit d38475eb2e
19 changed files with 44 additions and 66 deletions

View File

@ -340,11 +340,8 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
""" """
if tensor_util.is_tensor(iter_): if tensor_util.is_tensor(iter_):
if tensors.is_range_tensor(iter_): if tensors.is_range_tensor(iter_):
_tf_range_for_stmt(iter_, extra_test, body, get_state, set_state, _tf_range_for_stmt(
symbol_names, opts) iter_, extra_test, body, get_state, set_state, symbol_names, opts)
elif isinstance(iter_, ragged_tensor.RaggedTensor):
_tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
symbol_names, opts)
else: else:
_known_len_tf_for_stmt( _known_len_tf_for_stmt(
iter_, extra_test, body, get_state, set_state, symbol_names, opts) iter_, extra_test, body, get_state, set_state, symbol_names, opts)

View File

@ -116,7 +116,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
gc.collect() gc.collect()
tensors = [ tensors = [
o for o in gc.get_objects() if isinstance(o, tensor_like.TensorLike) o for o in gc.get_objects() if isinstance(o, tensor_like._TensorLike)
] ]
self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors)) self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors))

View File

@ -54,9 +54,13 @@ tensor_util = LazyLoader(
"tensor_util", globals(), "tensor_util", globals(),
"tensorflow.python.framework.tensor_util") "tensorflow.python.framework.tensor_util")
# pylint: disable=protected-access
_TensorLike = tensor_like._TensorLike
# pylint: enable=protected-access
@tf_export("IndexedSlices") @tf_export("IndexedSlices")
class IndexedSlices(tensor_like.TensorLike, composite_tensor.CompositeTensor): class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
"""A sparse representation of a set of tensor slices at given indices. """A sparse representation of a set of tensor slices at given indices.
This class is a simple wrapper for a pair of `Tensor` objects: This class is a simple wrapper for a pair of `Tensor` objects:
@ -305,7 +309,7 @@ def internal_convert_to_tensor_or_indexed_slices(value,
""" """
if isinstance(value, ops.EagerTensor) and not context.executing_eagerly(): if isinstance(value, ops.EagerTensor) and not context.executing_eagerly():
return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref) return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
elif isinstance(value, tensor_like.TensorLike): elif isinstance(value, _TensorLike):
if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
raise ValueError( raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %

View File

@ -94,6 +94,7 @@ _api_usage_gauge = monitoring.BoolGauge(
# pylint: disable=protected-access # pylint: disable=protected-access
_TensorLike = tensor_like._TensorLike
_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE _DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
# pylint: enable=protected-access # pylint: enable=protected-access
@ -289,7 +290,7 @@ def disable_tensor_equality():
@tf_export("Tensor") @tf_export("Tensor")
class Tensor(tensor_like.TensorLike): class Tensor(_TensorLike):
"""A tensor is a multidimensional array of elements represented by a """A tensor is a multidimensional array of elements represented by a
`tf.Tensor` object. All elements are of a single known data type. `tf.Tensor` object. All elements are of a single known data type.
@ -5982,12 +5983,9 @@ def _assert_same_graph(original_item, item):
Raises: Raises:
ValueError: if graphs do not match. ValueError: if graphs do not match.
""" """
original_graph = getattr(original_item, "graph", None) if original_item.graph is not item.graph:
graph = getattr(item, "graph", None) raise ValueError("%s must be from the same graph as %s." %
if original_graph and graph and original_graph is not graph: (item, original_item))
raise ValueError(
"%s must be from the same graph as %s (graphs are %s and %s)." %
(item, original_item, graph, original_graph))
def _get_graph_from_inputs(op_input_list, graph=None): def _get_graph_from_inputs(op_input_list, graph=None):
@ -6041,16 +6039,16 @@ def _get_graph_from_inputs(op_input_list, graph=None):
# TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
# up. # up.
graph_element = None graph_element = None
if (isinstance(op_input, (Operation, tensor_like.TensorLike)) and if (isinstance(op_input, (Operation, _TensorLike)) and
((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck
graph_element = op_input graph_element = op_input
elif isinstance(op_input, Tensor): else:
graph_element = _as_graph_element(op_input) graph_element = _as_graph_element(op_input)
if graph_element is not None: if graph_element is not None:
if not graph: if not graph:
original_graph_element = graph_element original_graph_element = graph_element
graph = getattr(graph_element, "graph", None) graph = graph_element.graph
elif original_graph_element is not None: elif original_graph_element is not None:
_assert_same_graph(original_graph_element, graph_element) _assert_same_graph(original_graph_element, graph_element)
elif graph_element.graph is not graph: elif graph_element.graph is not graph:

View File

@ -38,13 +38,14 @@ from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access # pylint: disable=protected-access
_TensorLike = tensor_like._TensorLike
_eval_using_default_session = ops._eval_using_default_session _eval_using_default_session = ops._eval_using_default_session
_override_helper = ops._override_helper _override_helper = ops._override_helper
# pylint: enable=protected-access # pylint: enable=protected-access
@tf_export("sparse.SparseTensor", "SparseTensor") @tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(tensor_like.TensorLike, composite_tensor.CompositeTensor): class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
"""Represents a sparse tensor. """Represents a sparse tensor.
TensorFlow represents a sparse tensor as three separate dense tensors: TensorFlow represents a sparse tensor as three separate dense tensors:

View File

@ -19,10 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
class TensorLike(object): # NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
"""TF-specific types TF operations are expected to natively support. class _TensorLike(object):
"""Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
Do not check this with isinstance directly; prefer instead using
`tf.is_tensor` to check whether converting to a tensor is necessary.
"""
pass pass

View File

@ -978,23 +978,17 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
@tf_export("is_tensor") @tf_export("is_tensor")
def is_tensor(x): # pylint: disable=invalid-name def is_tensor(x): # pylint: disable=invalid-name
"""Checks whether `x` is a TF-native type that can be passed to many TF ops. """Checks whether `x` is a tensor or "tensor-like".
Use is_tensor to differentiate types that can ingested by TensorFlow ops If `is_tensor(x)` returns `True`, it is safe to assume that `x` is a tensor or
without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and can be converted to a tensor using `ops.convert_to_tensor(x)`.
`tf.RaggedTensor`) from types that need to be converted into tensors before
they are ingested (e.g., numpy `ndarray` and Python scalars).
For example, in the following code block: Usage example:
```python >>> tf.is_tensor(tf.constant([[1,2,3],[4,5,6],[7,8,9]]))
if not tf.is_tensor(t): True
t = tf.convert_to_tensor(t) >>> tf.is_tensor("Hello World")
return t.dtype False
```
we check to make sure that `t` is a tensor (and convert it if not) before
accessing its `shape` and `dtype`.
Args: Args:
x: A python object to check. x: A python object to check.
@ -1002,7 +996,8 @@ def is_tensor(x): # pylint: disable=invalid-name
Returns: Returns:
`True` if `x` is a tensor or "tensor-like", `False` if not. `True` if `x` is a tensor or "tensor-like", `False` if not.
""" """
return (isinstance(x, (tensor_like.TensorLike, core.Tensor)) or # TODO(mdan): Remove these. Only keep core.Tensor.
return (isinstance(x, (tensor_like._TensorLike, core.Tensor)) or # pylint: disable=protected-access
ops.is_dense_tensor_like(x) or ops.is_dense_tensor_like(x) or
getattr(x, "is_tensor_like", False)) getattr(x, "is_tensor_like", False))

View File

@ -56,7 +56,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import util as tf_losses_utils from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
@ -1050,21 +1049,11 @@ def has_symbolic_tensors(ls):
def has_tensors(ls): def has_tensors(ls):
"""Returns true if `ls` contains tensors."""
# Note: at some point in time ragged tensors didn't count as tensors, so this
# returned false for ragged tensors. Making this return true fails some tests
# which would then require a steps_per_epoch argument.
if isinstance(ls, (list, tuple)): if isinstance(ls, (list, tuple)):
return any( return any(tensor_util.is_tensor(v) for v in ls)
tensor_util.is_tensor(v) and
not isinstance(v, ragged_tensor.RaggedTensor) for v in ls)
if isinstance(ls, dict): if isinstance(ls, dict):
return any( return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls))
tensor_util.is_tensor(v) and return tensor_util.is_tensor(ls)
not isinstance(v, ragged_tensor.RaggedTensor)
for _, v in six.iteritems(ls))
return tensor_util.is_tensor(ls) and not isinstance(
ls, ragged_tensor.RaggedTensor)
def get_metric_name(metric, weighted=False): def get_metric_name(metric, weighted=False):

View File

@ -30,7 +30,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_like
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec
@ -55,7 +54,7 @@ _convert_row_partition = RowPartition._convert_row_partition
@tf_export("RaggedTensor") @tf_export("RaggedTensor")
class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike): class RaggedTensor(composite_tensor.CompositeTensor):
"""Represents a ragged tensor. """Represents a ragged tensor.
A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are

View File

@ -1,7 +1,7 @@
path: "tensorflow.IndexedSlices" path: "tensorflow.IndexedSlices"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlices\'>" is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlices\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {

View File

@ -2,7 +2,6 @@ path: "tensorflow.RaggedTensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor\'>" is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "dtype" name: "dtype"

View File

@ -1,7 +1,7 @@
path: "tensorflow.SparseTensor" path: "tensorflow.SparseTensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>" is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {

View File

@ -1,7 +1,7 @@
path: "tensorflow.Tensor" path: "tensorflow.Tensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>" is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "OVERLOADABLE_OPERATORS" name: "OVERLOADABLE_OPERATORS"

View File

@ -1,7 +1,7 @@
path: "tensorflow.sparse.SparseTensor" path: "tensorflow.sparse.SparseTensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>" is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {

View File

@ -1,7 +1,7 @@
path: "tensorflow.IndexedSlices" path: "tensorflow.IndexedSlices"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlices\'>" is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlices\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {

View File

@ -2,7 +2,6 @@ path: "tensorflow.RaggedTensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor\'>" is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "dtype" name: "dtype"

View File

@ -1,7 +1,7 @@
path: "tensorflow.SparseTensor" path: "tensorflow.SparseTensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>" is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {

View File

@ -1,7 +1,7 @@
path: "tensorflow.Tensor" path: "tensorflow.Tensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>" is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "OVERLOADABLE_OPERATORS" name: "OVERLOADABLE_OPERATORS"

View File

@ -1,7 +1,7 @@
path: "tensorflow.sparse.SparseTensor" path: "tensorflow.sparse.SparseTensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>" is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>" is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {