Make core.Tensor the base type for Tensor and replace the register_dense_tensor_like with direct subclassing.

PiperOrigin-RevId: 311206817
Change-Id: Id8ae234516d5409d6b70612a99f9f0b3ed53dc7e
This commit is contained in:
Dan Moldovan 2020-05-12 14:50:15 -07:00 committed by TensorFlower Gardener
parent 1e07fa6448
commit c5caa29b5e
17 changed files with 52 additions and 133 deletions

View File

@ -530,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
# TODO(b/154762408) Remove this package group once it's no longer needed. # TODO(b/154762408) Remove this package group once it's no longer needed.
package_group(name = "composite_tensor_whitelist") package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(
name = "types_whitelist",
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
filegroup( filegroup(
name = "intel_binary_blob", name = "intel_binary_blob",
data = if_mkl_ml( data = if_mkl_ml(

View File

@ -230,6 +230,7 @@ py_library(
"//tensorflow/python/tools:module_util", "//tensorflow/python/tools:module_util",
"//tensorflow/python/tools/api/generator:create_python_api", "//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python/tpu:tpu_noestimator", "//tensorflow/python/tpu:tpu_noestimator",
"//tensorflow/python/types",
"//third_party/py/numpy", "//third_party/py/numpy",
], ],
) )

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -422,7 +423,8 @@ class DistributedVarOp(object):
return hash((self.name, self.graph, self.traceback, self.type)) return hash((self.name, self.graph, self.traceback, self.type))
class DistributedVariable(DistributedDelegate, variables_lib.Variable): class DistributedVariable(DistributedDelegate, variables_lib.Variable,
core.Tensor):
"""Holds a map from replica to variables.""" """Holds a map from replica to variables."""
# TODO(josh11b): Support changing the set of variables if e.g. if new # TODO(josh11b): Support changing the set of variables if e.g. if new
@ -741,9 +743,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
pass pass
ops.register_dense_tensor_like_type(DistributedVariable)
def _validate_colocate_extended(v, extended): def _validate_colocate_extended(v, extended):
variable_strategy = v._distribute_strategy # pylint: disable=protected-access variable_strategy = v._distribute_strategy # pylint: disable=protected-access
if variable_strategy.extended is not extended: if variable_strategy.extended is not extended:
@ -1380,7 +1379,7 @@ def value_container(val):
return val return val
class AggregatingVariable(variables_lib.Variable): class AggregatingVariable(variables_lib.Variable, core.Tensor):
"""A wrapper around a variable that aggregates updates across replicas.""" """A wrapper around a variable that aggregates updates across replicas."""
def __init__(self, strategy, v, aggregation): def __init__(self, strategy, v, aggregation):
@ -1649,4 +1648,3 @@ def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
ops.register_tensor_conversion_function(AggregatingVariable, ops.register_tensor_conversion_function(AggregatingVariable,
_tensor_conversion_aggregate) _tensor_conversion_aggregate)
ops.register_dense_tensor_like_type(AggregatingVariable)

View File

@ -56,6 +56,7 @@ from tensorflow.python.saved_model.model_utils import mode_keys
from tensorflow.python.tpu import tpu_strategy_util from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -623,10 +624,10 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
v = variables_lib.Variable( v = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation) 0., synchronization=synchronization, aggregation=aggregation)
# In cross replica context. # In cross replica context.
self.assertTrue(ops.is_dense_tensor_like(v)) self.assertIsInstance(v, core.Tensor)
# In replica context. # In replica context.
distribution.run( distribution.run(
lambda v: self.assertTrue(ops.is_dense_tensor_like(v)), args=(v,)) lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))
def testAssignReturnValueIsTensorLike(self, distribution, synchronization, def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
aggregation): aggregation):
@ -645,9 +646,9 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
# values is not allowed when aggregation is SUM. See # values is not allowed when aggregation is SUM. See
# `cross_device_ops.reduce_non_distributed_value`. # `cross_device_ops.reduce_non_distributed_value`.
delta = array_ops.identity(1.) delta = array_ops.identity(1.)
self.assertTrue(ops.is_dense_tensor_like(v.assign(delta))) self.assertIsInstance(v.assign(delta), core.Tensor)
self.assertTrue(ops.is_dense_tensor_like(v.assign_sub(delta))) self.assertIsInstance(v.assign_sub(delta), core.Tensor)
self.assertTrue(ops.is_dense_tensor_like(v.assign_add(delta))) self.assertIsInstance(v.assign_add(delta), core.Tensor)
# In cross replica context we return a PerReplica which is not Tensor like # In cross replica context we return a PerReplica which is not Tensor like
# yet. # yet.

View File

@ -62,6 +62,7 @@ from tensorflow.python.framework import versions
from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core as core_tf_types
from tensorflow.python.types import internal from tensorflow.python.types import internal
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils from tensorflow.python.util import decorator_utils
@ -213,53 +214,11 @@ def _as_graph_element(obj):
return None return None
_TENSOR_LIKE_TYPES = tuple() # Deprecated - do not use.
# This API to avoid breaking estimator and tensorflow-mesh which depend on this
# internal API. The stub should be safe to use after TF 2.3 is released.
def is_dense_tensor_like(t): def is_dense_tensor_like(t):
"""EXPERIMENTAL: Returns true if `t` implements the tensor interface. return isinstance(t, core_tf_types.Tensor)
See `register_dense_tensor_like_type()` for the current definition of a
"tensor-like type".
Args:
t: An object.
Returns:
True iff `t` is an instance of one of the registered "tensor-like" types.
"""
return isinstance(t, _TENSOR_LIKE_TYPES)
def register_dense_tensor_like_type(tensor_type):
"""EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
A "tensor-like type" can represent a single dense tensor, and implements
the `name`, `dtype` and `shape` properties.
Args:
tensor_type: A type implementing the tensor interface.
Raises:
TypeError: If `tensor_type` does not implement the tensor interface.
"""
if not (hasattr(tensor_type, "name") and
isinstance(tensor_type.name, property)):
raise TypeError("Type %s does not define a `name` property" %
tensor_type.__name__)
if not (hasattr(tensor_type, "dtype") and
isinstance(tensor_type.dtype, property)):
raise TypeError("Type %s does not define a `dtype` property" %
tensor_type.__name__)
if not (hasattr(tensor_type, "shape") and
isinstance(tensor_type.shape, property)):
raise TypeError("Type %s does not define a `shape` property" %
tensor_type.__name__)
# We expect this list to be small, so choose quadratic complexity
# for registration, so that we have a tuple that can be used for
# more efficient `isinstance` checks later.
global _TENSOR_LIKE_TYPES
_TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
def uid(): def uid():
@ -304,7 +263,7 @@ def disable_tensor_equality():
# TODO(mdan): This object should subclass Symbol, not just Tensor. # TODO(mdan): This object should subclass Symbol, not just Tensor.
@tf_export("Tensor") @tf_export("Tensor")
class Tensor(internal.NativeObject): class Tensor(internal.NativeObject, core_tf_types.Tensor):
"""A tensor is a multidimensional array of elements represented by a """A tensor is a multidimensional array of elements represented by a
`tf.Tensor` object. All elements are of a single known data type. `tf.Tensor` object. All elements are of a single known data type.
@ -1305,9 +1264,6 @@ class _EagerTensorBase(Tensor):
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase) EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
register_dense_tensor_like_type(Tensor)
@tf_export(v1=["convert_to_tensor"]) @tf_export(v1=["convert_to_tensor"])
def convert_to_tensor_v1(value, def convert_to_tensor_v1(value,
dtype=None, dtype=None,

View File

@ -3268,56 +3268,6 @@ class DeprecatedTest(test_util.TensorFlowTestCase):
test_ops.old() test_ops.old()
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
@test_util.disable_tfrt("Graph is not supported yet.")
def testSuccess(self):
op = ops.Operation(
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
t = op.outputs[0]
self.assertTrue(ops.is_dense_tensor_like(t))
v = variables.Variable([17])
self.assertTrue(ops.is_dense_tensor_like(v))
class BadClassNoName(object):
pass
class BadClassBadName(object):
def name(self):
pass
class BadClassNoDtype(object):
@property
def name(self):
pass
class BadClassBadDtype(object):
@property
def name(self):
pass
def dtype(self):
pass
def testBadClass(self):
with self.assertRaisesRegexp(TypeError, "`name`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassNoName)
with self.assertRaisesRegexp(TypeError, "`name`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassBadName)
with self.assertRaisesRegexp(TypeError, "`dtype`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassNoDtype)
with self.assertRaisesRegexp(TypeError, "`dtype`"):
ops.register_dense_tensor_like_type(
DenseTensorLikeTypeTest.BadClassBadDtype)
class NameScopeTest(test_util.TensorFlowTestCase): class NameScopeTest(test_util.TensorFlowTestCase):
def testStripAndPrependScope(self): def testStripAndPrependScope(self):

View File

@ -26,6 +26,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.types import core
from tensorflow.python.types import internal from tensorflow.python.types import internal
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -1009,7 +1010,7 @@ def is_tensor(x): # pylint: disable=invalid-name
`True` if `x` is a tensor or "tensor-like", `False` if not. `True` if `x` is a tensor or "tensor-like", `False` if not.
""" """
return (isinstance(x, internal.NativeObject) or return (isinstance(x, internal.NativeObject) or
ops.is_dense_tensor_like(x) or isinstance(x, core.Tensor) or
getattr(x, "is_tensor_like", False)) getattr(x, "is_tensor_like", False))

View File

@ -62,6 +62,7 @@ from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.types import core
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
@ -3143,7 +3144,7 @@ def _convert_scipy_sparse_tensor(value, expected_input):
The possibly-converted 'value'. The possibly-converted 'value'.
""" """
if issparse is not None and issparse(value): if issparse is not None and issparse(value):
if ops.is_dense_tensor_like(expected_input): if isinstance(expected_input, core.Tensor):
if ops.executing_eagerly_outside_functions(): if ops.executing_eagerly_outside_functions():
# In TF2 we do not silently densify sparse matrices. # In TF2 we do not silently densify sparse matrices.
raise ValueError('A SciPy sparse matrix was passed to a model ' raise ValueError('A SciPy sparse matrix was passed to a model '

View File

@ -23,9 +23,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.types import core
class AutoCastVariable(variables.Variable): class AutoCastVariable(variables.Variable, core.Tensor):
"""Variable that will cast itself to a different dtype in applicable contexts. """Variable that will cast itself to a different dtype in applicable contexts.
This class wraps a floating-point `tf.Variable`. It emulates the variable This class wraps a floating-point `tf.Variable`. It emulates the variable
@ -417,7 +418,6 @@ class AutoCastVariable(variables.Variable):
ops.register_tensor_conversion_function(AutoCastVariable, ops.register_tensor_conversion_function(AutoCastVariable,
AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access
ops.register_dense_tensor_like_type(AutoCastVariable)
def create_autocast_variable(variable): def create_autocast_variable(variable):

View File

@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_math_ops
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import * from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.types import core
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch from tensorflow.python.util import dispatch
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -1381,13 +1382,13 @@ def _autopacking_helper(list_or_tuple, dtype, name):
if context.executing_eagerly(): if context.executing_eagerly():
# NOTE: Fast path when all the items are tensors, this doesn't do any type # NOTE: Fast path when all the items are tensors, this doesn't do any type
# checking. # checking.
if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple): if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
return gen_array_ops.pack(list_or_tuple, name=name) return gen_array_ops.pack(list_or_tuple, name=name)
must_pack = False must_pack = False
converted_elems = [] converted_elems = []
with ops.name_scope(name) as scope: with ops.name_scope(name) as scope:
for i, elem in enumerate(list_or_tuple): for i, elem in enumerate(list_or_tuple):
if ops.is_dense_tensor_like(elem): if isinstance(elem, core.Tensor):
if dtype is not None and elem.dtype.base_dtype != dtype: if dtype is not None and elem.dtype.base_dtype != dtype:
raise TypeError("Cannot convert a list containing a tensor of dtype " raise TypeError("Cannot convert a list containing a tensor of dtype "
"%s to %s (Tensor is: %r)" % "%s to %s (Tensor is: %r)" %
@ -1396,7 +1397,7 @@ def _autopacking_helper(list_or_tuple, dtype, name):
must_pack = True must_pack = True
elif isinstance(elem, (list, tuple)): elif isinstance(elem, (list, tuple)):
converted_elem = _autopacking_helper(elem, dtype, str(i)) converted_elem = _autopacking_helper(elem, dtype, str(i))
if ops.is_dense_tensor_like(converted_elem): if isinstance(converted_elem, core.Tensor):
must_pack = True must_pack = True
converted_elems.append(converted_elem) converted_elems.append(converted_elem)
else: else:
@ -1404,7 +1405,7 @@ def _autopacking_helper(list_or_tuple, dtype, name):
if must_pack: if must_pack:
elems_as_tensors = [] elems_as_tensors = []
for i, elem in enumerate(converted_elems): for i, elem in enumerate(converted_elems):
if ops.is_dense_tensor_like(elem): if isinstance(elem, core.Tensor):
elems_as_tensors.append(elem) elems_as_tensors.append(elem)
else: else:
# NOTE(mrry): This is inefficient, but it enables us to # NOTE(mrry): This is inefficient, but it enables us to
@ -1429,7 +1430,7 @@ def _get_dtype_from_nested_lists(list_or_tuple):
such object exists. such object exists.
""" """
for elem in list_or_tuple: for elem in list_or_tuple:
if ops.is_dense_tensor_like(elem): if isinstance(elem, core.Tensor):
return elem.dtype.base_dtype return elem.dtype.base_dtype
elif isinstance(elem, (list, tuple)): elif isinstance(elem, (list, tuple)):
maybe_dtype = _get_dtype_from_nested_lists(elem) maybe_dtype = _get_dtype_from_nested_lists(elem)
@ -1441,7 +1442,7 @@ def _get_dtype_from_nested_lists(list_or_tuple):
def _cast_nested_seqs_to_dtype(dtype): def _cast_nested_seqs_to_dtype(dtype):
def _maybe_cast(elem): def _maybe_cast(elem):
if ops.is_dense_tensor_like(elem): if isinstance(elem, core.Tensor):
if dtype != elem.dtype.base_dtype: if dtype != elem.dtype.base_dtype:
elem = gen_math_ops.cast(elem, dtype) elem = gen_math_ops.cast(elem, dtype)
return elem return elem
@ -1455,7 +1456,7 @@ _NON_AUTOPACKABLE_TYPES.add(np.ndarray)
def _should_not_autopack(v): def _should_not_autopack(v):
# The condition we really want is # The condition we really want is
# ops.is_dense_tensor_like(...) # any(isinstance(elem, core.Tensor))
# but it is >5x slower due to abc.ABCMeta.__instancecheck__. # but it is >5x slower due to abc.ABCMeta.__instancecheck__.
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
# TODO(slebedev): add nest.all? # TODO(slebedev): add nest.all?

View File

@ -49,6 +49,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.gen_resource_variable_ops import * from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.deprecation import deprecated
@ -330,7 +331,7 @@ def variable_accessed(variable):
tape.variable_accessed(variable) tape.variable_accessed(variable)
class BaseResourceVariable(variables.VariableV1): class BaseResourceVariable(variables.VariableV1, core.Tensor):
"""A python variable from an existing handle.""" """A python variable from an existing handle."""
# TODO(wangpeng): Deprecate `constraint` when callers no long pass it in. # TODO(wangpeng): Deprecate `constraint` when callers no long pass it in.
@ -1830,7 +1831,6 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
# allowing instances of the class to be used as tensors. # allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(BaseResourceVariable, ops.register_tensor_conversion_function(BaseResourceVariable,
_dense_var_to_tensor) _dense_var_to_tensor)
ops.register_dense_tensor_like_type(BaseResourceVariable)
class _UnreadVariable(BaseResourceVariable): class _UnreadVariable(BaseResourceVariable):
@ -1955,9 +1955,6 @@ class _UnreadVariable(BaseResourceVariable):
return self._parent_op return self._parent_op
ops.register_dense_tensor_like_type(_UnreadVariable)
@ops.RegisterGradient("ReadVariableOp") @ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad): def _ReadGrad(_, grad):
"""Gradient for read op.""" """Gradient for read op."""

View File

@ -42,6 +42,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_contextlib
@ -1000,7 +1001,7 @@ class _VariableStore(object):
return initializer, initializing_from_value return initializer, initializing_from_value
class _LazyEvalTensor(object): class _LazyEvalTensor(core.Tensor):
"""A Tensor-like object that only evaluates its thunk when used.""" """A Tensor-like object that only evaluates its thunk when used."""
def __init__(self, thunk): def __init__(self, thunk):
@ -1069,8 +1070,6 @@ session.register_session_run_conversion_functions(
lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access
) )
ops.register_dense_tensor_like_type(_LazyEvalTensor)
# To stop regularization, use this regularizer # To stop regularization, use this regularizer
@tf_export(v1=["no_regularizer"]) @tf_export(v1=["no_regularizer"])

View File

@ -47,6 +47,7 @@ from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.types import core
def default_variable_creator(_, **kwds): def default_variable_creator(_, **kwds):
@ -264,6 +265,7 @@ class VariableMetaclass(type):
@tf_export("Variable", v1=[]) @tf_export("Variable", v1=[])
# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)): class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
"""See the [variable guide](https://tensorflow.org/guide/variable). """See the [variable guide](https://tensorflow.org/guide/variable).
@ -1551,7 +1553,7 @@ class VariableV1(Variable):
# TODO(apassos): do not repeat all comments here # TODO(apassos): do not repeat all comments here
class RefVariable(VariableV1): class RefVariable(VariableV1, core.Tensor):
"""Ref-based implementation of variables.""" """Ref-based implementation of variables."""
def __init__( def __init__(
@ -3032,7 +3034,6 @@ class PartitionedVariable(object):
# allowing instances of the class to be used as tensors. # allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(RefVariable, ops.register_tensor_conversion_function(RefVariable,
RefVariable._TensorConversionFunction) # pylint: disable=protected-access RefVariable._TensorConversionFunction) # pylint: disable=protected-access
ops.register_dense_tensor_like_type(RefVariable)
@tf_export(v1=["global_variables"]) @tf_export(v1=["global_variables"])

View File

@ -226,6 +226,7 @@ py_library(
deps = [ deps = [
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/profiler/internal:_pywrap_traceme", "//tensorflow/python/profiler/internal:_pywrap_traceme",
"//tensorflow/python/types",
"@six_archive//:six", "@six_archive//:six",
], ],
) )

View File

@ -27,6 +27,9 @@ py_strict_library(
"internal.py", "internal.py",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"], visibility = [
"//tensorflow:__subpackages__",
"//tensorflow:types_whitelist",
],
deps = [], deps = [],
) )

View File

@ -2,6 +2,7 @@ path: "tensorflow.Tensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>" is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>" is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "OVERLOADABLE_OPERATORS" name: "OVERLOADABLE_OPERATORS"

View File

@ -2,6 +2,7 @@ path: "tensorflow.Tensor"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>" is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>" is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {
name: "OVERLOADABLE_OPERATORS" name: "OVERLOADABLE_OPERATORS"