Make core.Tensor
the base type for Tensor and replace the register_dense_tensor_like
with direct subclassing.
PiperOrigin-RevId: 311206817 Change-Id: Id8ae234516d5409d6b70612a99f9f0b3ed53dc7e
This commit is contained in:
parent
1e07fa6448
commit
c5caa29b5e
@ -530,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
|
|||||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||||
package_group(name = "composite_tensor_whitelist")
|
package_group(name = "composite_tensor_whitelist")
|
||||||
|
|
||||||
|
# Packages that use private types symbols, until they are exported.
|
||||||
|
# TODO(b/154650521) Remove.
|
||||||
|
package_group(
|
||||||
|
name = "types_whitelist",
|
||||||
|
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "intel_binary_blob",
|
name = "intel_binary_blob",
|
||||||
data = if_mkl_ml(
|
data = if_mkl_ml(
|
||||||
|
@ -230,6 +230,7 @@ py_library(
|
|||||||
"//tensorflow/python/tools:module_util",
|
"//tensorflow/python/tools:module_util",
|
||||||
"//tensorflow/python/tools/api/generator:create_python_api",
|
"//tensorflow/python/tools/api/generator:create_python_api",
|
||||||
"//tensorflow/python/tpu:tpu_noestimator",
|
"//tensorflow/python/tpu:tpu_noestimator",
|
||||||
|
"//tensorflow/python/types",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -38,6 +38,7 @@ from tensorflow.python.ops import variables as variables_lib
|
|||||||
from tensorflow.python.training.saving import saveable_object
|
from tensorflow.python.training.saving import saveable_object
|
||||||
from tensorflow.python.training.saving import saveable_object_util
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
|
from tensorflow.python.types import core
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -422,7 +423,8 @@ class DistributedVarOp(object):
|
|||||||
return hash((self.name, self.graph, self.traceback, self.type))
|
return hash((self.name, self.graph, self.traceback, self.type))
|
||||||
|
|
||||||
|
|
||||||
class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||||
|
core.Tensor):
|
||||||
"""Holds a map from replica to variables."""
|
"""Holds a map from replica to variables."""
|
||||||
|
|
||||||
# TODO(josh11b): Support changing the set of variables if e.g. if new
|
# TODO(josh11b): Support changing the set of variables if e.g. if new
|
||||||
@ -741,9 +743,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
ops.register_dense_tensor_like_type(DistributedVariable)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_colocate_extended(v, extended):
|
def _validate_colocate_extended(v, extended):
|
||||||
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
|
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
|
||||||
if variable_strategy.extended is not extended:
|
if variable_strategy.extended is not extended:
|
||||||
@ -1380,7 +1379,7 @@ def value_container(val):
|
|||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
class AggregatingVariable(variables_lib.Variable):
|
class AggregatingVariable(variables_lib.Variable, core.Tensor):
|
||||||
"""A wrapper around a variable that aggregates updates across replicas."""
|
"""A wrapper around a variable that aggregates updates across replicas."""
|
||||||
|
|
||||||
def __init__(self, strategy, v, aggregation):
|
def __init__(self, strategy, v, aggregation):
|
||||||
@ -1649,4 +1648,3 @@ def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
|
|||||||
|
|
||||||
ops.register_tensor_conversion_function(AggregatingVariable,
|
ops.register_tensor_conversion_function(AggregatingVariable,
|
||||||
_tensor_conversion_aggregate)
|
_tensor_conversion_aggregate)
|
||||||
ops.register_dense_tensor_like_type(AggregatingVariable)
|
|
||||||
|
@ -56,6 +56,7 @@ from tensorflow.python.saved_model.model_utils import mode_keys
|
|||||||
from tensorflow.python.tpu import tpu_strategy_util
|
from tensorflow.python.tpu import tpu_strategy_util
|
||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
from tensorflow.python.training.tracking import util as trackable_utils
|
from tensorflow.python.training.tracking import util as trackable_utils
|
||||||
|
from tensorflow.python.types import core
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
@ -623,10 +624,10 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
v = variables_lib.Variable(
|
v = variables_lib.Variable(
|
||||||
0., synchronization=synchronization, aggregation=aggregation)
|
0., synchronization=synchronization, aggregation=aggregation)
|
||||||
# In cross replica context.
|
# In cross replica context.
|
||||||
self.assertTrue(ops.is_dense_tensor_like(v))
|
self.assertIsInstance(v, core.Tensor)
|
||||||
# In replica context.
|
# In replica context.
|
||||||
distribution.run(
|
distribution.run(
|
||||||
lambda v: self.assertTrue(ops.is_dense_tensor_like(v)), args=(v,))
|
lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))
|
||||||
|
|
||||||
def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
|
def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
|
||||||
aggregation):
|
aggregation):
|
||||||
@ -645,9 +646,9 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
|||||||
# values is not allowed when aggregation is SUM. See
|
# values is not allowed when aggregation is SUM. See
|
||||||
# `cross_device_ops.reduce_non_distributed_value`.
|
# `cross_device_ops.reduce_non_distributed_value`.
|
||||||
delta = array_ops.identity(1.)
|
delta = array_ops.identity(1.)
|
||||||
self.assertTrue(ops.is_dense_tensor_like(v.assign(delta)))
|
self.assertIsInstance(v.assign(delta), core.Tensor)
|
||||||
self.assertTrue(ops.is_dense_tensor_like(v.assign_sub(delta)))
|
self.assertIsInstance(v.assign_sub(delta), core.Tensor)
|
||||||
self.assertTrue(ops.is_dense_tensor_like(v.assign_add(delta)))
|
self.assertIsInstance(v.assign_add(delta), core.Tensor)
|
||||||
|
|
||||||
# In cross replica context we return a PerReplica which is not Tensor like
|
# In cross replica context we return a PerReplica which is not Tensor like
|
||||||
# yet.
|
# yet.
|
||||||
|
@ -62,6 +62,7 @@ from tensorflow.python.framework import versions
|
|||||||
from tensorflow.python.ops import control_flow_util
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.platform import app
|
from tensorflow.python.platform import app
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.types import core as core_tf_types
|
||||||
from tensorflow.python.types import internal
|
from tensorflow.python.types import internal
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import decorator_utils
|
from tensorflow.python.util import decorator_utils
|
||||||
@ -213,53 +214,11 @@ def _as_graph_element(obj):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
_TENSOR_LIKE_TYPES = tuple()
|
# Deprecated - do not use.
|
||||||
|
# This API to avoid breaking estimator and tensorflow-mesh which depend on this
|
||||||
|
# internal API. The stub should be safe to use after TF 2.3 is released.
|
||||||
def is_dense_tensor_like(t):
|
def is_dense_tensor_like(t):
|
||||||
"""EXPERIMENTAL: Returns true if `t` implements the tensor interface.
|
return isinstance(t, core_tf_types.Tensor)
|
||||||
|
|
||||||
See `register_dense_tensor_like_type()` for the current definition of a
|
|
||||||
"tensor-like type".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
t: An object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True iff `t` is an instance of one of the registered "tensor-like" types.
|
|
||||||
"""
|
|
||||||
return isinstance(t, _TENSOR_LIKE_TYPES)
|
|
||||||
|
|
||||||
|
|
||||||
def register_dense_tensor_like_type(tensor_type):
|
|
||||||
"""EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
|
|
||||||
|
|
||||||
A "tensor-like type" can represent a single dense tensor, and implements
|
|
||||||
the `name`, `dtype` and `shape` properties.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor_type: A type implementing the tensor interface.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If `tensor_type` does not implement the tensor interface.
|
|
||||||
"""
|
|
||||||
if not (hasattr(tensor_type, "name") and
|
|
||||||
isinstance(tensor_type.name, property)):
|
|
||||||
raise TypeError("Type %s does not define a `name` property" %
|
|
||||||
tensor_type.__name__)
|
|
||||||
if not (hasattr(tensor_type, "dtype") and
|
|
||||||
isinstance(tensor_type.dtype, property)):
|
|
||||||
raise TypeError("Type %s does not define a `dtype` property" %
|
|
||||||
tensor_type.__name__)
|
|
||||||
if not (hasattr(tensor_type, "shape") and
|
|
||||||
isinstance(tensor_type.shape, property)):
|
|
||||||
raise TypeError("Type %s does not define a `shape` property" %
|
|
||||||
tensor_type.__name__)
|
|
||||||
# We expect this list to be small, so choose quadratic complexity
|
|
||||||
# for registration, so that we have a tuple that can be used for
|
|
||||||
# more efficient `isinstance` checks later.
|
|
||||||
global _TENSOR_LIKE_TYPES
|
|
||||||
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
|
|
||||||
|
|
||||||
|
|
||||||
def uid():
|
def uid():
|
||||||
@ -304,7 +263,7 @@ def disable_tensor_equality():
|
|||||||
|
|
||||||
# TODO(mdan): This object should subclass Symbol, not just Tensor.
|
# TODO(mdan): This object should subclass Symbol, not just Tensor.
|
||||||
@tf_export("Tensor")
|
@tf_export("Tensor")
|
||||||
class Tensor(internal.NativeObject):
|
class Tensor(internal.NativeObject, core_tf_types.Tensor):
|
||||||
"""A tensor is a multidimensional array of elements represented by a
|
"""A tensor is a multidimensional array of elements represented by a
|
||||||
|
|
||||||
`tf.Tensor` object. All elements are of a single known data type.
|
`tf.Tensor` object. All elements are of a single known data type.
|
||||||
@ -1305,9 +1264,6 @@ class _EagerTensorBase(Tensor):
|
|||||||
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||||
|
|
||||||
|
|
||||||
register_dense_tensor_like_type(Tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["convert_to_tensor"])
|
@tf_export(v1=["convert_to_tensor"])
|
||||||
def convert_to_tensor_v1(value,
|
def convert_to_tensor_v1(value,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
|
@ -3268,56 +3268,6 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
|
|||||||
test_ops.old()
|
test_ops.old()
|
||||||
|
|
||||||
|
|
||||||
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
|
|
||||||
|
|
||||||
@test_util.disable_tfrt("Graph is not supported yet.")
|
|
||||||
def testSuccess(self):
|
|
||||||
op = ops.Operation(
|
|
||||||
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
|
|
||||||
t = op.outputs[0]
|
|
||||||
self.assertTrue(ops.is_dense_tensor_like(t))
|
|
||||||
|
|
||||||
v = variables.Variable([17])
|
|
||||||
self.assertTrue(ops.is_dense_tensor_like(v))
|
|
||||||
|
|
||||||
class BadClassNoName(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class BadClassBadName(object):
|
|
||||||
|
|
||||||
def name(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class BadClassNoDtype(object):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class BadClassBadDtype(object):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def dtype(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def testBadClass(self):
|
|
||||||
with self.assertRaisesRegexp(TypeError, "`name`"):
|
|
||||||
ops.register_dense_tensor_like_type(
|
|
||||||
DenseTensorLikeTypeTest.BadClassNoName)
|
|
||||||
with self.assertRaisesRegexp(TypeError, "`name`"):
|
|
||||||
ops.register_dense_tensor_like_type(
|
|
||||||
DenseTensorLikeTypeTest.BadClassBadName)
|
|
||||||
with self.assertRaisesRegexp(TypeError, "`dtype`"):
|
|
||||||
ops.register_dense_tensor_like_type(
|
|
||||||
DenseTensorLikeTypeTest.BadClassNoDtype)
|
|
||||||
with self.assertRaisesRegexp(TypeError, "`dtype`"):
|
|
||||||
ops.register_dense_tensor_like_type(
|
|
||||||
DenseTensorLikeTypeTest.BadClassBadDtype)
|
|
||||||
|
|
||||||
|
|
||||||
class NameScopeTest(test_util.TensorFlowTestCase):
|
class NameScopeTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testStripAndPrependScope(self):
|
def testStripAndPrependScope(self):
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.types import core
|
||||||
from tensorflow.python.types import internal
|
from tensorflow.python.types import internal
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -1009,7 +1010,7 @@ def is_tensor(x): # pylint: disable=invalid-name
|
|||||||
`True` if `x` is a tensor or "tensor-like", `False` if not.
|
`True` if `x` is a tensor or "tensor-like", `False` if not.
|
||||||
"""
|
"""
|
||||||
return (isinstance(x, internal.NativeObject) or
|
return (isinstance(x, internal.NativeObject) or
|
||||||
ops.is_dense_tensor_like(x) or
|
isinstance(x, core.Tensor) or
|
||||||
getattr(x, "is_tensor_like", False))
|
getattr(x, "is_tensor_like", False))
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,6 +62,7 @@ from tensorflow.python.ops.losses import util as tf_losses_utils
|
|||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||||
|
from tensorflow.python.types import core
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
@ -3143,7 +3144,7 @@ def _convert_scipy_sparse_tensor(value, expected_input):
|
|||||||
The possibly-converted 'value'.
|
The possibly-converted 'value'.
|
||||||
"""
|
"""
|
||||||
if issparse is not None and issparse(value):
|
if issparse is not None and issparse(value):
|
||||||
if ops.is_dense_tensor_like(expected_input):
|
if isinstance(expected_input, core.Tensor):
|
||||||
if ops.executing_eagerly_outside_functions():
|
if ops.executing_eagerly_outside_functions():
|
||||||
# In TF2 we do not silently densify sparse matrices.
|
# In TF2 we do not silently densify sparse matrices.
|
||||||
raise ValueError('A SciPy sparse matrix was passed to a model '
|
raise ValueError('A SciPy sparse matrix was passed to a model '
|
||||||
|
@ -23,9 +23,10 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.types import core
|
||||||
|
|
||||||
|
|
||||||
class AutoCastVariable(variables.Variable):
|
class AutoCastVariable(variables.Variable, core.Tensor):
|
||||||
"""Variable that will cast itself to a different dtype in applicable contexts.
|
"""Variable that will cast itself to a different dtype in applicable contexts.
|
||||||
|
|
||||||
This class wraps a floating-point `tf.Variable`. It emulates the variable
|
This class wraps a floating-point `tf.Variable`. It emulates the variable
|
||||||
@ -417,7 +418,6 @@ class AutoCastVariable(variables.Variable):
|
|||||||
|
|
||||||
ops.register_tensor_conversion_function(AutoCastVariable,
|
ops.register_tensor_conversion_function(AutoCastVariable,
|
||||||
AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access
|
AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access
|
||||||
ops.register_dense_tensor_like_type(AutoCastVariable)
|
|
||||||
|
|
||||||
|
|
||||||
def create_autocast_variable(variable):
|
def create_autocast_variable(variable):
|
||||||
|
@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_math_ops
|
|||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
from tensorflow.python.ops.gen_array_ops import *
|
from tensorflow.python.ops.gen_array_ops import *
|
||||||
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
|
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.types import core
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import dispatch
|
from tensorflow.python.util import dispatch
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -1381,13 +1382,13 @@ def _autopacking_helper(list_or_tuple, dtype, name):
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
# NOTE: Fast path when all the items are tensors, this doesn't do any type
|
# NOTE: Fast path when all the items are tensors, this doesn't do any type
|
||||||
# checking.
|
# checking.
|
||||||
if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple):
|
if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
|
||||||
return gen_array_ops.pack(list_or_tuple, name=name)
|
return gen_array_ops.pack(list_or_tuple, name=name)
|
||||||
must_pack = False
|
must_pack = False
|
||||||
converted_elems = []
|
converted_elems = []
|
||||||
with ops.name_scope(name) as scope:
|
with ops.name_scope(name) as scope:
|
||||||
for i, elem in enumerate(list_or_tuple):
|
for i, elem in enumerate(list_or_tuple):
|
||||||
if ops.is_dense_tensor_like(elem):
|
if isinstance(elem, core.Tensor):
|
||||||
if dtype is not None and elem.dtype.base_dtype != dtype:
|
if dtype is not None and elem.dtype.base_dtype != dtype:
|
||||||
raise TypeError("Cannot convert a list containing a tensor of dtype "
|
raise TypeError("Cannot convert a list containing a tensor of dtype "
|
||||||
"%s to %s (Tensor is: %r)" %
|
"%s to %s (Tensor is: %r)" %
|
||||||
@ -1396,7 +1397,7 @@ def _autopacking_helper(list_or_tuple, dtype, name):
|
|||||||
must_pack = True
|
must_pack = True
|
||||||
elif isinstance(elem, (list, tuple)):
|
elif isinstance(elem, (list, tuple)):
|
||||||
converted_elem = _autopacking_helper(elem, dtype, str(i))
|
converted_elem = _autopacking_helper(elem, dtype, str(i))
|
||||||
if ops.is_dense_tensor_like(converted_elem):
|
if isinstance(converted_elem, core.Tensor):
|
||||||
must_pack = True
|
must_pack = True
|
||||||
converted_elems.append(converted_elem)
|
converted_elems.append(converted_elem)
|
||||||
else:
|
else:
|
||||||
@ -1404,7 +1405,7 @@ def _autopacking_helper(list_or_tuple, dtype, name):
|
|||||||
if must_pack:
|
if must_pack:
|
||||||
elems_as_tensors = []
|
elems_as_tensors = []
|
||||||
for i, elem in enumerate(converted_elems):
|
for i, elem in enumerate(converted_elems):
|
||||||
if ops.is_dense_tensor_like(elem):
|
if isinstance(elem, core.Tensor):
|
||||||
elems_as_tensors.append(elem)
|
elems_as_tensors.append(elem)
|
||||||
else:
|
else:
|
||||||
# NOTE(mrry): This is inefficient, but it enables us to
|
# NOTE(mrry): This is inefficient, but it enables us to
|
||||||
@ -1429,7 +1430,7 @@ def _get_dtype_from_nested_lists(list_or_tuple):
|
|||||||
such object exists.
|
such object exists.
|
||||||
"""
|
"""
|
||||||
for elem in list_or_tuple:
|
for elem in list_or_tuple:
|
||||||
if ops.is_dense_tensor_like(elem):
|
if isinstance(elem, core.Tensor):
|
||||||
return elem.dtype.base_dtype
|
return elem.dtype.base_dtype
|
||||||
elif isinstance(elem, (list, tuple)):
|
elif isinstance(elem, (list, tuple)):
|
||||||
maybe_dtype = _get_dtype_from_nested_lists(elem)
|
maybe_dtype = _get_dtype_from_nested_lists(elem)
|
||||||
@ -1441,7 +1442,7 @@ def _get_dtype_from_nested_lists(list_or_tuple):
|
|||||||
def _cast_nested_seqs_to_dtype(dtype):
|
def _cast_nested_seqs_to_dtype(dtype):
|
||||||
|
|
||||||
def _maybe_cast(elem):
|
def _maybe_cast(elem):
|
||||||
if ops.is_dense_tensor_like(elem):
|
if isinstance(elem, core.Tensor):
|
||||||
if dtype != elem.dtype.base_dtype:
|
if dtype != elem.dtype.base_dtype:
|
||||||
elem = gen_math_ops.cast(elem, dtype)
|
elem = gen_math_ops.cast(elem, dtype)
|
||||||
return elem
|
return elem
|
||||||
@ -1455,7 +1456,7 @@ _NON_AUTOPACKABLE_TYPES.add(np.ndarray)
|
|||||||
|
|
||||||
def _should_not_autopack(v):
|
def _should_not_autopack(v):
|
||||||
# The condition we really want is
|
# The condition we really want is
|
||||||
# ops.is_dense_tensor_like(...)
|
# any(isinstance(elem, core.Tensor))
|
||||||
# but it is >5x slower due to abc.ABCMeta.__instancecheck__.
|
# but it is >5x slower due to abc.ABCMeta.__instancecheck__.
|
||||||
# pylint: disable=unidiomatic-typecheck
|
# pylint: disable=unidiomatic-typecheck
|
||||||
# TODO(slebedev): add nest.all?
|
# TODO(slebedev): add nest.all?
|
||||||
|
@ -49,6 +49,7 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.ops.gen_resource_variable_ops import *
|
from tensorflow.python.ops.gen_resource_variable_ops import *
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
|
from tensorflow.python.types import core
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util.deprecation import deprecated
|
from tensorflow.python.util.deprecation import deprecated
|
||||||
|
|
||||||
@ -330,7 +331,7 @@ def variable_accessed(variable):
|
|||||||
tape.variable_accessed(variable)
|
tape.variable_accessed(variable)
|
||||||
|
|
||||||
|
|
||||||
class BaseResourceVariable(variables.VariableV1):
|
class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||||
"""A python variable from an existing handle."""
|
"""A python variable from an existing handle."""
|
||||||
|
|
||||||
# TODO(wangpeng): Deprecate `constraint` when callers no long pass it in.
|
# TODO(wangpeng): Deprecate `constraint` when callers no long pass it in.
|
||||||
@ -1830,7 +1831,6 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
|
|||||||
# allowing instances of the class to be used as tensors.
|
# allowing instances of the class to be used as tensors.
|
||||||
ops.register_tensor_conversion_function(BaseResourceVariable,
|
ops.register_tensor_conversion_function(BaseResourceVariable,
|
||||||
_dense_var_to_tensor)
|
_dense_var_to_tensor)
|
||||||
ops.register_dense_tensor_like_type(BaseResourceVariable)
|
|
||||||
|
|
||||||
|
|
||||||
class _UnreadVariable(BaseResourceVariable):
|
class _UnreadVariable(BaseResourceVariable):
|
||||||
@ -1955,9 +1955,6 @@ class _UnreadVariable(BaseResourceVariable):
|
|||||||
return self._parent_op
|
return self._parent_op
|
||||||
|
|
||||||
|
|
||||||
ops.register_dense_tensor_like_type(_UnreadVariable)
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("ReadVariableOp")
|
@ops.RegisterGradient("ReadVariableOp")
|
||||||
def _ReadGrad(_, grad):
|
def _ReadGrad(_, grad):
|
||||||
"""Gradient for read op."""
|
"""Gradient for read op."""
|
||||||
|
@ -42,6 +42,7 @@ from tensorflow.python.ops import init_ops
|
|||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.types import core
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import function_utils
|
from tensorflow.python.util import function_utils
|
||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
@ -1000,7 +1001,7 @@ class _VariableStore(object):
|
|||||||
return initializer, initializing_from_value
|
return initializer, initializing_from_value
|
||||||
|
|
||||||
|
|
||||||
class _LazyEvalTensor(object):
|
class _LazyEvalTensor(core.Tensor):
|
||||||
"""A Tensor-like object that only evaluates its thunk when used."""
|
"""A Tensor-like object that only evaluates its thunk when used."""
|
||||||
|
|
||||||
def __init__(self, thunk):
|
def __init__(self, thunk):
|
||||||
@ -1069,8 +1070,6 @@ session.register_session_run_conversion_functions(
|
|||||||
lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access
|
lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
|
|
||||||
ops.register_dense_tensor_like_type(_LazyEvalTensor)
|
|
||||||
|
|
||||||
|
|
||||||
# To stop regularization, use this regularizer
|
# To stop regularization, use this regularizer
|
||||||
@tf_export(v1=["no_regularizer"])
|
@tf_export(v1=["no_regularizer"])
|
||||||
|
@ -47,6 +47,7 @@ from tensorflow.python.util import tf_should_use
|
|||||||
from tensorflow.python.util.deprecation import deprecated
|
from tensorflow.python.util.deprecation import deprecated
|
||||||
from tensorflow.python.util.deprecation import deprecated_args
|
from tensorflow.python.util.deprecation import deprecated_args
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
from tensorflow.python.types import core
|
||||||
|
|
||||||
|
|
||||||
def default_variable_creator(_, **kwds):
|
def default_variable_creator(_, **kwds):
|
||||||
@ -264,6 +265,7 @@ class VariableMetaclass(type):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("Variable", v1=[])
|
@tf_export("Variable", v1=[])
|
||||||
|
# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
|
||||||
class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
|
||||||
"""See the [variable guide](https://tensorflow.org/guide/variable).
|
"""See the [variable guide](https://tensorflow.org/guide/variable).
|
||||||
|
|
||||||
@ -1551,7 +1553,7 @@ class VariableV1(Variable):
|
|||||||
|
|
||||||
|
|
||||||
# TODO(apassos): do not repeat all comments here
|
# TODO(apassos): do not repeat all comments here
|
||||||
class RefVariable(VariableV1):
|
class RefVariable(VariableV1, core.Tensor):
|
||||||
"""Ref-based implementation of variables."""
|
"""Ref-based implementation of variables."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -3032,7 +3034,6 @@ class PartitionedVariable(object):
|
|||||||
# allowing instances of the class to be used as tensors.
|
# allowing instances of the class to be used as tensors.
|
||||||
ops.register_tensor_conversion_function(RefVariable,
|
ops.register_tensor_conversion_function(RefVariable,
|
||||||
RefVariable._TensorConversionFunction) # pylint: disable=protected-access
|
RefVariable._TensorConversionFunction) # pylint: disable=protected-access
|
||||||
ops.register_dense_tensor_like_type(RefVariable)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["global_variables"])
|
@tf_export(v1=["global_variables"])
|
||||||
|
@ -226,6 +226,7 @@ py_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
||||||
|
"//tensorflow/python/types",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -27,6 +27,9 @@ py_strict_library(
|
|||||||
"internal.py",
|
"internal.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//tensorflow:__subpackages__"],
|
visibility = [
|
||||||
|
"//tensorflow:__subpackages__",
|
||||||
|
"//tensorflow:types_whitelist",
|
||||||
|
],
|
||||||
deps = [],
|
deps = [],
|
||||||
)
|
)
|
||||||
|
@ -2,6 +2,7 @@ path: "tensorflow.Tensor"
|
|||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
||||||
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
|
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "OVERLOADABLE_OPERATORS"
|
name: "OVERLOADABLE_OPERATORS"
|
||||||
|
@ -2,6 +2,7 @@ path: "tensorflow.Tensor"
|
|||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
||||||
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
|
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "OVERLOADABLE_OPERATORS"
|
name: "OVERLOADABLE_OPERATORS"
|
||||||
|
Loading…
Reference in New Issue
Block a user