diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f07bca17061..7e51d3a330d 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1506,6 +1506,13 @@ def convert_to_tensor(value, if preferred_dtype is not None: preferred_dtype = dtypes.as_dtype(preferred_dtype) + + # See below for the reason why it's `type(value)` and not just `value`. + # https://docs.python.org/3.8/reference/datamodel.html#special-lookup + overload = getattr(type(value), "__tf_tensor__", None) + if overload is not None: + return overload(value, dtype, name) + for base_type, conversion_func in tensor_conversion_registry.get(type(value)): # If dtype is None but preferred_dtype is not None, we try to # cast to preferred_dtype first. @@ -2333,6 +2340,10 @@ class Operation(object): def __repr__(self): return "<tf.Operation '%s' type=%s>" % (self.name, self.type) + def __tf_tensor__(self, dtype=None, name=None): + """Raises a helpful error.""" + raise TypeError("can't convert Operation '{}' to Tensor".format(self.name)) + @property def outputs(self): """The list of `Tensor` objects representing the outputs of this op.""" @@ -6833,13 +6844,6 @@ def get_from_proto_function(collection_name): return None -def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): - """Produce a nice error if someone converts an Operation to a Tensor.""" - raise TypeError(("Can't convert Operation '%s' to Tensor " - "(target dtype=%r, name=%r, as_ref=%r)") % - (op.name, dtype, name, as_ref)) - - def _op_to_colocate_with(v, graph): """Operation object corresponding to v to use for colocation constraints.""" if v is None: @@ -6873,10 +6877,6 @@ def _is_keras_symbolic_tensor(x): return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph" -tensor_conversion_registry.register_tensor_conversion_function( - Operation, _operation_conversion_error) - - # These symbols were originally defined in this module; import them for # backwards compatibility until all references have been updated to access # them from the indexed_slices.py module. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 4129b55e3fd..58e3f650c44 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -858,12 +858,25 @@ class OperationTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): ops.convert_to_tensor(tensor, dtype=dtypes.int32) + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorProtocol(self): + class TensorCompatible: + + def __tf_tensor__(self, dtype=None, name=None): + return constant_op.constant((1, 2, 3), dtype=dtype, name=name) + + tc = TensorCompatible() + + tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32) + self.assertEqual(tensor.dtype, dtypes.int32) + self.assertAllEqual((1, 2, 3), self.evaluate(tensor)) + @test_util.run_deprecated_v1 def testNoConvert(self): # Operation cannot be converted to Tensor. op = control_flow_ops.no_op() with self.assertRaisesRegex(TypeError, - r"Can't convert Operation '.*' to Tensor"): + "can't convert Operation '.+' to Tensor"): ops.convert_to_tensor(op) def testStr(self): diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD index 5f3f4fd0e31..d48f066d294 100644 --- a/tensorflow/python/types/BUILD +++ b/tensorflow/python/types/BUILD @@ -35,6 +35,7 @@ py_strict_library( ":doc_typealias", "//tensorflow/python:tf_export", "//third_party/py/numpy", + "@typing_extensions_archive//:typing_extensions", ], ) diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py index bec5aecaba0..b4506594a82 100644 --- a/tensorflow/python/types/core.py +++ b/tensorflow/python/types/core.py @@ -18,14 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys import textwrap from typing import Union + import numpy as np from tensorflow.python.types import doc_typealias from tensorflow.python.util.tf_export import tf_export +if sys.version_info >= (3, 8): + from typing import Protocol # pylint:disable=g-import-not-at-top +else: + from typing_extensions import Protocol # pylint:disable=g-import-not-at-top + # TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced. # TODO(mdan): Add type annotations. @@ -67,9 +74,24 @@ class Value(Tensor): pass +class TensorProtocol(Protocol): + """Protocol type for objects that can be converted to Tensor.""" + + def __tf_tensor__(self, dtype=None, name=None): + """Converts this object to a Tensor. + + Args: + dtype: data type for the returned Tensor + name: a name for the operations which create the Tensor + Returns: + A Tensor. + """ + pass + + # TODO(rahulkamat): Add missing types that are convertible to Tensor. -TensorLike = Union[Tensor, int, float, bool, str, complex, tuple, list, - np.ndarray] +TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, complex, + tuple, list, np.ndarray] doc_typealias.document( obj=TensorLike, doc=textwrap.dedent("""\