Make core.Tensor the base type for Tensor and replace the register_dense_tensor_like with direct subclassing.

PiperOrigin-RevId: 311206817
Change-Id: Id8ae234516d5409d6b70612a99f9f0b3ed53dc7e
This commit is contained in:
Dan Moldovan 2020-05-12 14:50:15 -07:00 committed by TensorFlower Gardener
parent 1e07fa6448
commit c5caa29b5e
17 changed files with 52 additions and 133 deletions

View File

@ -530,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
# TODO(b/154762408) Remove this package group once it's no longer needed.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(
name = "types_whitelist",
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(

View File

@ -230,6 +230,7 @@ py_library(
"//tensorflow/python/tools:module_util",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python/tpu:tpu_noestimator",
"//tensorflow/python/types",
"//third_party/py/numpy",
],
)

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -422,7 +423,8 @@ class DistributedVarOp(object):
return hash((self.name, self.graph, self.traceback, self.type))
class DistributedVariable(DistributedDelegate, variables_lib.Variable):
class DistributedVariable(DistributedDelegate, variables_lib.Variable,
core.Tensor):
"""Holds a map from replica to variables."""
# TODO(josh11b): Support changing the set of variables if e.g. if new
@ -741,9 +743,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
pass
ops.register_dense_tensor_like_type(DistributedVariable)
def _validate_colocate_extended(v, extended):
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
if variable_strategy.extended is not extended:
@ -1380,7 +1379,7 @@ def value_container(val):
return val
class AggregatingVariable(variables_lib.Variable):
class AggregatingVariable(variables_lib.Variable, core.Tensor):
"""A wrapper around a variable that aggregates updates across replicas."""
def __init__(self, strategy, v, aggregation):
@ -1649,4 +1648,3 @@ def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
ops.register_tensor_conversion_function(AggregatingVariable,
_tensor_conversion_aggregate)
ops.register_dense_tensor_like_type(AggregatingVariable)

View File

@ -56,6 +56,7 @@ from tensorflow.python.saved_model.model_utils import mode_keys
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest
@ -623,10 +624,10 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
v = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation)
# In cross replica context.
self.assertTrue(ops.is_dense_tensor_like(v))
self.assertIsInstance(v, core.Tensor)
# In replica context.
distribution.run(
lambda v: self.assertTrue(ops.is_dense_tensor_like(v)), args=(v,))
lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))
def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
aggregation):
@ -645,9 +646,9 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
# values is not allowed when aggregation is SUM. See
# `cross_device_ops.reduce_non_distributed_value`.
delta = array_ops.identity(1.)
self.assertTrue(ops.is_dense_tensor_like(v.assign(delta)))
self.assertTrue(ops.is_dense_tensor_like(v.assign_sub(delta)))
self.assertTrue(ops.is_dense_tensor_like(v.assign_add(delta)))
self.assertIsInstance(v.assign(delta), core.Tensor)
self.assertIsInstance(v.assign_sub(delta), core.Tensor)
self.assertIsInstance(v.assign_add(delta), core.Tensor)
# In cross replica context we return a PerReplica which is not Tensor like
# yet.

View File

@ -62,6 +62,7 @@ from tensorflow.python.framework import versions
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core as core_tf_types
from tensorflow.python.types import internal
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
@ -213,53 +214,11 @@ def _as_graph_element(obj):
return None
_TENSOR_LIKE_TYPES = tuple()
# Deprecated - do not use.
# This API to avoid breaking estimator and tensorflow-mesh which depend on this
# internal API. The stub should be safe to use after TF 2.3 is released.
def is_dense_tensor_like(t):
"""EXPERIMENTAL: Returns true if `t` implements the tensor interface.
See `register_dense_tensor_like_type()` for the current definition of a
"tensor-like type".
Args:
t: An object.
Returns:
True iff `t` is an instance of one of the registered "tensor-like" types.
"""
return isinstance(t, _TENSOR_LIKE_TYPES)
def register_dense_tensor_like_type(tensor_type):
"""EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
A "tensor-like type" can represent a single dense tensor, and implements
the `name`, `dtype` and `shape` properties.
Args:
tensor_type: A type implementing the tensor interface.
Raises:
TypeError: If `tensor_type` does not implement the tensor interface.
"""
if not (hasattr(tensor_type, "name") and
isinstance(tensor_type.name, property)):
raise TypeError("Type %s does not define a `name` property" %
tensor_type.__name__)
if not (hasattr(tensor_type, "dtype") and
isinstance(tensor_type.dtype, property)):
raise TypeError("Type %s does not define a `dtype` property" %
tensor_type.__name__)
if not (hasattr(tensor_type, "shape") and
isinstance(tensor_type.shape, property)):
raise TypeError("Type %s does not define a `shape` property" %
tensor_type.__name__)
# We expect this list to be small, so choose quadratic complexity
# for registration, so that we have a tuple that can be used for
# more efficient `isinstance` checks later.
global _TENSOR_LIKE_TYPES
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
return isinstance(t, core_tf_types.Tensor)
def uid():
@ -304,7 +263,7 @@ def disable_tensor_equality():
# TODO(mdan): This object should subclass Symbol, not just Tensor.
@tf_export("Tensor")
class Tensor(internal.NativeObject):
class Tensor(internal.NativeObject, core_tf_types.Tensor):
"""A tensor is a multidimensional array of elements represented by a
`tf.Tensor` object. All elements are of a single known data type.
@ -1305,9 +1264,6 @@ class _EagerTensorBase(Tensor):
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
register_dense_tensor_like_type(Tensor)
@tf_export(v1=["convert_to_tensor"])
def convert_to_tensor_v1(value,
dtype=None,

View File

@ -3268,56 +3268,6 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
test_ops.old()
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
@test_util.disable_tfrt("Graph is not supported yet.")
def testSuccess(self):
op = ops.Operation(
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
t = op.outputs[0]
self.assertTrue(ops.is_dense_tensor_like(t))
v = variables.Variable([17])
self.assertTrue(ops.is_dense_tensor_like(v))
class BadClassNoName(object):
pass
class BadClassBadName(object):
def name(self):
pass
class BadClassNoDtype(object):
@property
def name(self):
pass
class BadClassBadDtype(object):
@property
def name(self):
pass
def dtype(self):
pass
def testBadClass(self):
with self.assertRaisesRegexp(TypeError, "`name`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassNoName)
with self.assertRaisesRegexp(TypeError, "`name`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassBadName)
with self.assertRaisesRegexp(TypeError, "`dtype`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassNoDtype)
with self.assertRaisesRegexp(TypeError, "`dtype`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassBadDtype)
class NameScopeTest(test_util.TensorFlowTestCase):
def testStripAndPrependScope(self):

View File

@ -26,6 +26,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.types import core
from tensorflow.python.types import internal
from tensorflow.python.util import compat
from tensorflow.python.util import nest
@ -1009,7 +1010,7 @@ def is_tensor(x): # pylint: disable=invalid-name
`True` if `x` is a tensor or "tensor-like", `False` if not.
"""
return (isinstance(x, internal.NativeObject) or
ops.is_dense_tensor_like(x) or
isinstance(x, core.Tensor) or
getattr(x, "is_tensor_like", False))

View File

@ -62,6 +62,7 @@ from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@ -3143,7 +3144,7 @@ def _convert_scipy_sparse_tensor(value, expected_input):
The possibly-converted 'value'.
"""
if issparse is not None and issparse(value):
if ops.is_dense_tensor_like(expected_input):
if isinstance(expected_input, core.Tensor):
if ops.executing_eagerly_outside_functions():
# In TF2 we do not silently densify sparse matrices.
raise ValueError('A SciPy sparse matrix was passed to a model '

View File

@ -23,9 +23,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.types import core
class AutoCastVariable(variables.Variable):
class AutoCastVariable(variables.Variable, core.Tensor):
"""Variable that will cast itself to a different dtype in applicable contexts.
This class wraps a floating-point `tf.Variable`. It emulates the variable
@ -417,7 +418,6 @@ class AutoCastVariable(variables.Variable):
ops.register_tensor_conversion_function(AutoCastVariable,
AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access
ops.register_dense_tensor_like_type(AutoCastVariable)
def create_autocast_variable(variable):

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
@ -1381,13 +1382,13 @@ def _autopacking_helper(list_or_tuple, dtype, name):
if context.executing_eagerly():
# NOTE: Fast path when all the items are tensors, this doesn't do any type
# checking.
if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple):
if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
return gen_array_ops.pack(list_or_tuple, name=name)
must_pack = False
converted_elems = []
with ops.name_scope(name) as scope:
for i, elem in enumerate(list_or_tuple):
if ops.is_dense_tensor_like(elem):
if isinstance(elem, core.Tensor):
if dtype is not None and elem.dtype.base_dtype != dtype:
raise TypeError("Cannot convert a list containing a tensor of dtype "
"%s to %s (Tensor is: %r)" %
@ -1396,7 +1397,7 @@ def _autopacking_helper(list_or_tuple, dtype, name):
must_pack = True
elif isinstance(elem, (list, tuple)):
converted_elem = _autopacking_helper(elem, dtype, str(i))
if ops.is_dense_tensor_like(converted_elem):
if isinstance(converted_elem, core.Tensor):
must_pack = True
converted_elems.append(converted_elem)
else:
@ -1404,7 +1405,7 @@ def _autopacking_helper(list_or_tuple, dtype, name):
if must_pack:
elems_as_tensors = []
for i, elem in enumerate(converted_elems):
if ops.is_dense_tensor_like(elem):
if isinstance(elem, core.Tensor):
elems_as_tensors.append(elem)
else:
# NOTE(mrry): This is inefficient, but it enables us to
@ -1429,7 +1430,7 @@ def _get_dtype_from_nested_lists(list_or_tuple):
such object exists.
"""
for elem in list_or_tuple:
if ops.is_dense_tensor_like(elem):
if isinstance(elem, core.Tensor):
return elem.dtype.base_dtype
elif isinstance(elem, (list, tuple)):
maybe_dtype = _get_dtype_from_nested_lists(elem)
@ -1441,7 +1442,7 @@ def _get_dtype_from_nested_lists(list_or_tuple):
def _cast_nested_seqs_to_dtype(dtype):
def _maybe_cast(elem):
if ops.is_dense_tensor_like(elem):
if isinstance(elem, core.Tensor):
if dtype != elem.dtype.base_dtype:
elem = gen_math_ops.cast(elem, dtype)
return elem
@ -1455,7 +1456,7 @@ _NON_AUTOPACKABLE_TYPES.add(np.ndarray)
def _should_not_autopack(v):
# The condition we really want is
# ops.is_dense_tensor_like(...)
# any(isinstance(elem, core.Tensor))
# but it is >5x slower due to abc.ABCMeta.__instancecheck__.
# pylint: disable=unidiomatic-typecheck
# TODO(slebedev): add nest.all?

View File

@ -49,6 +49,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@ -330,7 +331,7 @@ def variable_accessed(variable):
tape.variable_accessed(variable)
class BaseResourceVariable(variables.VariableV1):
class BaseResourceVariable(variables.VariableV1, core.Tensor):
"""A python variable from an existing handle."""
# TODO(wangpeng): Deprecate `constraint` when callers no long pass it in.
@ -1830,7 +1831,6 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(BaseResourceVariable,
_dense_var_to_tensor)
ops.register_dense_tensor_like_type(BaseResourceVariable)
class _UnreadVariable(BaseResourceVariable):
@ -1955,9 +1955,6 @@ class _UnreadVariable(BaseResourceVariable):
return self._parent_op
ops.register_dense_tensor_like_type(_UnreadVariable)
@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
"""Gradient for read op."""

View File

@ -42,6 +42,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
@ -1000,7 +1001,7 @@ class _VariableStore(object):
return initializer, initializing_from_value
class _LazyEvalTensor(object):
class _LazyEvalTensor(core.Tensor):
"""A Tensor-like object that only evaluates its thunk when used."""
def __init__(self, thunk):
@ -1069,8 +1070,6 @@ session.register_session_run_conversion_functions(
lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access
)
ops.register_dense_tensor_like_type(_LazyEvalTensor)
# To stop regularization, use this regularizer
@tf_export(v1=["no_regularizer"])

View File

@ -47,6 +47,7 @@ from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.types import core
def default_variable_creator(_, **kwds):
@ -264,6 +265,7 @@ class VariableMetaclass(type):
@tf_export("Variable", v1=[])
# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
"""See the [variable guide](https://tensorflow.org/guide/variable).
@ -1551,7 +1553,7 @@ class VariableV1(Variable):
# TODO(apassos): do not repeat all comments here
class RefVariable(VariableV1):
class RefVariable(VariableV1, core.Tensor):
"""Ref-based implementation of variables."""
def __init__(
@ -3032,7 +3034,6 @@ class PartitionedVariable(object):
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(RefVariable,
RefVariable._TensorConversionFunction) # pylint: disable=protected-access
ops.register_dense_tensor_like_type(RefVariable)
@tf_export(v1=["global_variables"])

View File

@ -226,6 +226,7 @@ py_library(
deps = [
"//tensorflow/python:util",
"//tensorflow/python/profiler/internal:_pywrap_traceme",
"//tensorflow/python/types",
"@six_archive//:six",
],
)

View File

@ -27,6 +27,9 @@ py_strict_library(
"internal.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
visibility = [
"//tensorflow:__subpackages__",
"//tensorflow:types_whitelist",
],
deps = [],
)

View File

@ -2,6 +2,7 @@ path: "tensorflow.Tensor"
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>"
member {
name: "OVERLOADABLE_OPERATORS"

View File

@ -2,6 +2,7 @@ path: "tensorflow.Tensor"
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>"
member {
name: "OVERLOADABLE_OPERATORS"