Add support for the __tf_tensor__ protocol to convert_to_tensor (see RFC: TensorFlow Canonical Type System). Use it instead of the conversion registry for Operation objects.
PiperOrigin-RevId: 328546755 Change-Id: I45ca8277c11631ac86473e074f00364eaad70796
This commit is contained in:
parent
fe36fba52b
commit
40a314cab8
@ -1506,6 +1506,13 @@ def convert_to_tensor(value,
|
||||
|
||||
if preferred_dtype is not None:
|
||||
preferred_dtype = dtypes.as_dtype(preferred_dtype)
|
||||
|
||||
# See below for the reason why it's `type(value)` and not just `value`.
|
||||
# https://docs.python.org/3.8/reference/datamodel.html#special-lookup
|
||||
overload = getattr(type(value), "__tf_tensor__", None)
|
||||
if overload is not None:
|
||||
return overload(value, dtype, name)
|
||||
|
||||
for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
|
||||
# If dtype is None but preferred_dtype is not None, we try to
|
||||
# cast to preferred_dtype first.
|
||||
@ -2333,6 +2340,10 @@ class Operation(object):
|
||||
def __repr__(self):
|
||||
return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
|
||||
|
||||
def __tf_tensor__(self, dtype=None, name=None):
|
||||
"""Raises a helpful error."""
|
||||
raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
"""The list of `Tensor` objects representing the outputs of this op."""
|
||||
@ -6833,13 +6844,6 @@ def get_from_proto_function(collection_name):
|
||||
return None
|
||||
|
||||
|
||||
def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
|
||||
"""Produce a nice error if someone converts an Operation to a Tensor."""
|
||||
raise TypeError(("Can't convert Operation '%s' to Tensor "
|
||||
"(target dtype=%r, name=%r, as_ref=%r)") %
|
||||
(op.name, dtype, name, as_ref))
|
||||
|
||||
|
||||
def _op_to_colocate_with(v, graph):
|
||||
"""Operation object corresponding to v to use for colocation constraints."""
|
||||
if v is None:
|
||||
@ -6873,10 +6877,6 @@ def _is_keras_symbolic_tensor(x):
|
||||
return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
|
||||
|
||||
|
||||
tensor_conversion_registry.register_tensor_conversion_function(
|
||||
Operation, _operation_conversion_error)
|
||||
|
||||
|
||||
# These symbols were originally defined in this module; import them for
|
||||
# backwards compatibility until all references have been updated to access
|
||||
# them from the indexed_slices.py module.
|
||||
|
@ -858,12 +858,25 @@ class OperationTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.convert_to_tensor(tensor, dtype=dtypes.int32)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConvertToTensorProtocol(self):
|
||||
class TensorCompatible:
|
||||
|
||||
def __tf_tensor__(self, dtype=None, name=None):
|
||||
return constant_op.constant((1, 2, 3), dtype=dtype, name=name)
|
||||
|
||||
tc = TensorCompatible()
|
||||
|
||||
tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32)
|
||||
self.assertEqual(tensor.dtype, dtypes.int32)
|
||||
self.assertAllEqual((1, 2, 3), self.evaluate(tensor))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNoConvert(self):
|
||||
# Operation cannot be converted to Tensor.
|
||||
op = control_flow_ops.no_op()
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
r"Can't convert Operation '.*' to Tensor"):
|
||||
"can't convert Operation '.+' to Tensor"):
|
||||
ops.convert_to_tensor(op)
|
||||
|
||||
def testStr(self):
|
||||
|
@ -35,6 +35,7 @@ py_strict_library(
|
||||
":doc_typealias",
|
||||
"//tensorflow/python:tf_export",
|
||||
"//third_party/py/numpy",
|
||||
"@typing_extensions_archive//:typing_extensions",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,14 +18,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.types import doc_typealias
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Protocol # pylint:disable=g-import-not-at-top
|
||||
else:
|
||||
from typing_extensions import Protocol # pylint:disable=g-import-not-at-top
|
||||
|
||||
# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced.
|
||||
# TODO(mdan): Add type annotations.
|
||||
|
||||
@ -67,9 +74,24 @@ class Value(Tensor):
|
||||
pass
|
||||
|
||||
|
||||
class TensorProtocol(Protocol):
|
||||
"""Protocol type for objects that can be converted to Tensor."""
|
||||
|
||||
def __tf_tensor__(self, dtype=None, name=None):
|
||||
"""Converts this object to a Tensor.
|
||||
|
||||
Args:
|
||||
dtype: data type for the returned Tensor
|
||||
name: a name for the operations which create the Tensor
|
||||
Returns:
|
||||
A Tensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# TODO(rahulkamat): Add missing types that are convertible to Tensor.
|
||||
TensorLike = Union[Tensor, int, float, bool, str, complex, tuple, list,
|
||||
np.ndarray]
|
||||
TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, complex,
|
||||
tuple, list, np.ndarray]
|
||||
doc_typealias.document(
|
||||
obj=TensorLike,
|
||||
doc=textwrap.dedent("""\
|
||||
|
Loading…
Reference in New Issue
Block a user