Add support for the __tf_tensor__ protocol to convert_to_tensor (see RFC: TensorFlow Canonical Type System). Use it instead of the conversion registry for Operation objects.

PiperOrigin-RevId: 328546755
Change-Id: I45ca8277c11631ac86473e074f00364eaad70796
This commit is contained in:
Dan Moldovan 2020-08-26 09:29:54 -07:00 committed by TensorFlower Gardener
parent fe36fba52b
commit 40a314cab8
4 changed files with 50 additions and 14 deletions

View File

@ -1506,6 +1506,13 @@ def convert_to_tensor(value,
if preferred_dtype is not None:
preferred_dtype = dtypes.as_dtype(preferred_dtype)
# See below for the reason why it's `type(value)` and not just `value`.
# https://docs.python.org/3.8/reference/datamodel.html#special-lookup
overload = getattr(type(value), "__tf_tensor__", None)
if overload is not None:
return overload(value, dtype, name)
for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
# If dtype is None but preferred_dtype is not None, we try to
# cast to preferred_dtype first.
@ -2333,6 +2340,10 @@ class Operation(object):
def __repr__(self):
return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
def __tf_tensor__(self, dtype=None, name=None):
"""Raises a helpful error."""
raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))
@property
def outputs(self):
"""The list of `Tensor` objects representing the outputs of this op."""
@ -6833,13 +6844,6 @@ def get_from_proto_function(collection_name):
return None
def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
"""Produce a nice error if someone converts an Operation to a Tensor."""
raise TypeError(("Can't convert Operation '%s' to Tensor "
"(target dtype=%r, name=%r, as_ref=%r)") %
(op.name, dtype, name, as_ref))
def _op_to_colocate_with(v, graph):
"""Operation object corresponding to v to use for colocation constraints."""
if v is None:
@ -6873,10 +6877,6 @@ def _is_keras_symbolic_tensor(x):
return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
tensor_conversion_registry.register_tensor_conversion_function(
Operation, _operation_conversion_error)
# These symbols were originally defined in this module; import them for
# backwards compatibility until all references have been updated to access
# them from the indexed_slices.py module.

View File

@ -858,12 +858,25 @@ class OperationTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
ops.convert_to_tensor(tensor, dtype=dtypes.int32)
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorProtocol(self):
class TensorCompatible:
def __tf_tensor__(self, dtype=None, name=None):
return constant_op.constant((1, 2, 3), dtype=dtype, name=name)
tc = TensorCompatible()
tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32)
self.assertEqual(tensor.dtype, dtypes.int32)
self.assertAllEqual((1, 2, 3), self.evaluate(tensor))
@test_util.run_deprecated_v1
def testNoConvert(self):
# Operation cannot be converted to Tensor.
op = control_flow_ops.no_op()
with self.assertRaisesRegex(TypeError,
r"Can't convert Operation '.*' to Tensor"):
"can't convert Operation '.+' to Tensor"):
ops.convert_to_tensor(op)
def testStr(self):

View File

@ -35,6 +35,7 @@ py_strict_library(
":doc_typealias",
"//tensorflow/python:tf_export",
"//third_party/py/numpy",
"@typing_extensions_archive//:typing_extensions",
],
)

View File

@ -18,14 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import textwrap
from typing import Union
import numpy as np
from tensorflow.python.types import doc_typealias
from tensorflow.python.util.tf_export import tf_export
if sys.version_info >= (3, 8):
from typing import Protocol # pylint:disable=g-import-not-at-top
else:
from typing_extensions import Protocol # pylint:disable=g-import-not-at-top
# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced.
# TODO(mdan): Add type annotations.
@ -67,9 +74,24 @@ class Value(Tensor):
pass
class TensorProtocol(Protocol):
"""Protocol type for objects that can be converted to Tensor."""
def __tf_tensor__(self, dtype=None, name=None):
"""Converts this object to a Tensor.
Args:
dtype: data type for the returned Tensor
name: a name for the operations which create the Tensor
Returns:
A Tensor.
"""
pass
# TODO(rahulkamat): Add missing types that are convertible to Tensor.
TensorLike = Union[Tensor, int, float, bool, str, complex, tuple, list,
np.ndarray]
TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, complex,
tuple, list, np.ndarray]
doc_typealias.document(
obj=TensorLike,
doc=textwrap.dedent("""\