Clarifies meaning of is_tensor in the documentation.
Also clarifies documentation of TensorLike. Makes RaggedTensor subclass TensorLike. Also renames _TensorLike to TensorLike since we expect more usages of it. PiperOrigin-RevId: 303357719 Change-Id: Ieea5b0941e2149e2ad4670e449c72c471685cf37
This commit is contained in:
parent
6766a8ee3e
commit
d38475eb2e
@ -340,11 +340,8 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts):
|
||||
"""
|
||||
if tensor_util.is_tensor(iter_):
|
||||
if tensors.is_range_tensor(iter_):
|
||||
_tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
|
||||
symbol_names, opts)
|
||||
elif isinstance(iter_, ragged_tensor.RaggedTensor):
|
||||
_tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
|
||||
symbol_names, opts)
|
||||
_tf_range_for_stmt(
|
||||
iter_, extra_test, body, get_state, set_state, symbol_names, opts)
|
||||
else:
|
||||
_known_len_tf_for_stmt(
|
||||
iter_, extra_test, body, get_state, set_state, symbol_names, opts)
|
||||
|
@ -116,7 +116,7 @@ class MemoryCleanupTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
gc.collect()
|
||||
tensors = [
|
||||
o for o in gc.get_objects() if isinstance(o, tensor_like.TensorLike)
|
||||
o for o in gc.get_objects() if isinstance(o, tensor_like._TensorLike)
|
||||
]
|
||||
self.assertEmpty(tensors, "%d Tensors are still alive." % len(tensors))
|
||||
|
||||
|
@ -54,9 +54,13 @@ tensor_util = LazyLoader(
|
||||
"tensor_util", globals(),
|
||||
"tensorflow.python.framework.tensor_util")
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TensorLike = tensor_like._TensorLike
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@tf_export("IndexedSlices")
|
||||
class IndexedSlices(tensor_like.TensorLike, composite_tensor.CompositeTensor):
|
||||
class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
|
||||
"""A sparse representation of a set of tensor slices at given indices.
|
||||
|
||||
This class is a simple wrapper for a pair of `Tensor` objects:
|
||||
@ -305,7 +309,7 @@ def internal_convert_to_tensor_or_indexed_slices(value,
|
||||
"""
|
||||
if isinstance(value, ops.EagerTensor) and not context.executing_eagerly():
|
||||
return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
|
||||
elif isinstance(value, tensor_like.TensorLike):
|
||||
elif isinstance(value, _TensorLike):
|
||||
if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
|
||||
raise ValueError(
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
|
||||
|
@ -94,6 +94,7 @@ _api_usage_gauge = monitoring.BoolGauge(
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TensorLike = tensor_like._TensorLike
|
||||
_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@ -289,7 +290,7 @@ def disable_tensor_equality():
|
||||
|
||||
|
||||
@tf_export("Tensor")
|
||||
class Tensor(tensor_like.TensorLike):
|
||||
class Tensor(_TensorLike):
|
||||
"""A tensor is a multidimensional array of elements represented by a
|
||||
|
||||
`tf.Tensor` object. All elements are of a single known data type.
|
||||
@ -5982,12 +5983,9 @@ def _assert_same_graph(original_item, item):
|
||||
Raises:
|
||||
ValueError: if graphs do not match.
|
||||
"""
|
||||
original_graph = getattr(original_item, "graph", None)
|
||||
graph = getattr(item, "graph", None)
|
||||
if original_graph and graph and original_graph is not graph:
|
||||
raise ValueError(
|
||||
"%s must be from the same graph as %s (graphs are %s and %s)." %
|
||||
(item, original_item, graph, original_graph))
|
||||
if original_item.graph is not item.graph:
|
||||
raise ValueError("%s must be from the same graph as %s." %
|
||||
(item, original_item))
|
||||
|
||||
|
||||
def _get_graph_from_inputs(op_input_list, graph=None):
|
||||
@ -6041,16 +6039,16 @@ def _get_graph_from_inputs(op_input_list, graph=None):
|
||||
# TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
|
||||
# up.
|
||||
graph_element = None
|
||||
if (isinstance(op_input, (Operation, tensor_like.TensorLike)) and
|
||||
if (isinstance(op_input, (Operation, _TensorLike)) and
|
||||
((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck
|
||||
graph_element = op_input
|
||||
elif isinstance(op_input, Tensor):
|
||||
else:
|
||||
graph_element = _as_graph_element(op_input)
|
||||
|
||||
if graph_element is not None:
|
||||
if not graph:
|
||||
original_graph_element = graph_element
|
||||
graph = getattr(graph_element, "graph", None)
|
||||
graph = graph_element.graph
|
||||
elif original_graph_element is not None:
|
||||
_assert_same_graph(original_graph_element, graph_element)
|
||||
elif graph_element.graph is not graph:
|
||||
|
@ -38,13 +38,14 @@ from tensorflow.python.ops import gen_sparse_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TensorLike = tensor_like._TensorLike
|
||||
_eval_using_default_session = ops._eval_using_default_session
|
||||
_override_helper = ops._override_helper
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@tf_export("sparse.SparseTensor", "SparseTensor")
|
||||
class SparseTensor(tensor_like.TensorLike, composite_tensor.CompositeTensor):
|
||||
class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
|
||||
"""Represents a sparse tensor.
|
||||
|
||||
TensorFlow represents a sparse tensor as three separate dense tensors:
|
||||
|
@ -19,10 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class TensorLike(object):
|
||||
"""TF-specific types TF operations are expected to natively support.
|
||||
|
||||
Do not check this with isinstance directly; prefer instead using
|
||||
`tf.is_tensor` to check whether converting to a tensor is necessary.
|
||||
"""
|
||||
# NOTE(ebrevdo): Do not subclass this. If you do, I will break you on purpose.
|
||||
class _TensorLike(object):
|
||||
"""Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
|
||||
pass
|
||||
|
@ -978,23 +978,17 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
|
||||
|
||||
@tf_export("is_tensor")
|
||||
def is_tensor(x): # pylint: disable=invalid-name
|
||||
"""Checks whether `x` is a TF-native type that can be passed to many TF ops.
|
||||
"""Checks whether `x` is a tensor or "tensor-like".
|
||||
|
||||
Use is_tensor to differentiate types that can ingested by TensorFlow ops
|
||||
without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and
|
||||
`tf.RaggedTensor`) from types that need to be converted into tensors before
|
||||
they are ingested (e.g., numpy `ndarray` and Python scalars).
|
||||
If `is_tensor(x)` returns `True`, it is safe to assume that `x` is a tensor or
|
||||
can be converted to a tensor using `ops.convert_to_tensor(x)`.
|
||||
|
||||
For example, in the following code block:
|
||||
Usage example:
|
||||
|
||||
```python
|
||||
if not tf.is_tensor(t):
|
||||
t = tf.convert_to_tensor(t)
|
||||
return t.dtype
|
||||
```
|
||||
|
||||
we check to make sure that `t` is a tensor (and convert it if not) before
|
||||
accessing its `shape` and `dtype`.
|
||||
>>> tf.is_tensor(tf.constant([[1,2,3],[4,5,6],[7,8,9]]))
|
||||
True
|
||||
>>> tf.is_tensor("Hello World")
|
||||
False
|
||||
|
||||
Args:
|
||||
x: A python object to check.
|
||||
@ -1002,7 +996,8 @@ def is_tensor(x): # pylint: disable=invalid-name
|
||||
Returns:
|
||||
`True` if `x` is a tensor or "tensor-like", `False` if not.
|
||||
"""
|
||||
return (isinstance(x, (tensor_like.TensorLike, core.Tensor)) or
|
||||
# TODO(mdan): Remove these. Only keep core.Tensor.
|
||||
return (isinstance(x, (tensor_like._TensorLike, core.Tensor)) or # pylint: disable=protected-access
|
||||
ops.is_dense_tensor_like(x) or
|
||||
getattr(x, "is_tensor_like", False))
|
||||
|
||||
|
@ -56,7 +56,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.losses import util as tf_losses_utils
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
@ -1050,21 +1049,11 @@ def has_symbolic_tensors(ls):
|
||||
|
||||
|
||||
def has_tensors(ls):
|
||||
"""Returns true if `ls` contains tensors."""
|
||||
# Note: at some point in time ragged tensors didn't count as tensors, so this
|
||||
# returned false for ragged tensors. Making this return true fails some tests
|
||||
# which would then require a steps_per_epoch argument.
|
||||
if isinstance(ls, (list, tuple)):
|
||||
return any(
|
||||
tensor_util.is_tensor(v) and
|
||||
not isinstance(v, ragged_tensor.RaggedTensor) for v in ls)
|
||||
return any(tensor_util.is_tensor(v) for v in ls)
|
||||
if isinstance(ls, dict):
|
||||
return any(
|
||||
tensor_util.is_tensor(v) and
|
||||
not isinstance(v, ragged_tensor.RaggedTensor)
|
||||
for _, v in six.iteritems(ls))
|
||||
return tensor_util.is_tensor(ls) and not isinstance(
|
||||
ls, ragged_tensor.RaggedTensor)
|
||||
return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls))
|
||||
return tensor_util.is_tensor(ls)
|
||||
|
||||
|
||||
def get_metric_name(metric, weighted=False):
|
||||
|
@ -30,7 +30,6 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_like
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
@ -55,7 +54,7 @@ _convert_row_partition = RowPartition._convert_row_partition
|
||||
|
||||
|
||||
@tf_export("RaggedTensor")
|
||||
class RaggedTensor(composite_tensor.CompositeTensor, tensor_like.TensorLike):
|
||||
class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
"""Represents a ragged tensor.
|
||||
|
||||
A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.IndexedSlices"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlices\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -2,7 +2,6 @@ path: "tensorflow.RaggedTensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dtype"
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.SparseTensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.Tensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "OVERLOADABLE_OPERATORS"
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.sparse.SparseTensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.IndexedSlices"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.indexed_slices.IndexedSlices\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -2,7 +2,6 @@ path: "tensorflow.RaggedTensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dtype"
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.SparseTensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.Tensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "OVERLOADABLE_OPERATORS"
|
||||
|
@ -1,7 +1,7 @@
|
||||
path: "tensorflow.sparse.SparseTensor"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like.TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.tensor_like._TensorLike\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
Loading…
Reference in New Issue
Block a user