From de9ba3670d7a6b02bfc39a35ec6434917ed8423a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 28 Jun 2019 05:29:35 -0700 Subject: [PATCH] Generalized *Tensor fast-path in internal_convert_to_tensor Note that Tensor no longer has a conversion function because it never reaches the slow path iterating over the registered converters. PiperOrigin-RevId: 255584682 --- tensorflow/python/framework/ops.py | 39 ++++++++++++------------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 72bff2d3ade..c716f08f1fe 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1030,15 +1030,6 @@ class _EagerTensorBase(Tensor): EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase) -def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False): - _ = name, as_ref - if dtype and not dtype.is_compatible_with(t.dtype): - raise ValueError( - "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % - (dtype.name, t.dtype.name, str(t))) - return t -tensor_conversion_registry.register_tensor_conversion_function( - Tensor, _TensorTensorConversionFunction, 0) register_dense_tensor_like_type(Tensor) @@ -1172,23 +1163,23 @@ def internal_convert_to_tensor(value, ctx=None, accept_composite_tensors=False): """Implementation of the public convert_to_tensor.""" - if ctx is None: - ctx = context.context() - if isinstance(value, EagerTensor): - if ctx.executing_eagerly(): - if dtype is not None: - dtype = dtypes.as_dtype(dtype) - value = _TensorTensorConversionFunction(value, dtype=dtype) - return value - else: - graph = get_default_graph() - if not graph.building_function: - raise RuntimeError("Attempting to capture an EagerTensor without " - "building a function.") - return graph.capture(value, name=name) - if dtype is not None: dtype = dtypes.as_dtype(dtype) + if ctx is None: + ctx = context.context() + if isinstance(value, EagerTensor) and not ctx.executing_eagerly(): + graph = get_default_graph() + if not graph.building_function: + raise RuntimeError("Attempting to capture an EagerTensor without " + "building a function.") + return graph.capture(value, name=name) + elif isinstance(value, Tensor): + if dtype is not None and not dtype.is_compatible_with(value.dtype): + raise ValueError( + "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % + (dtype.name, value.dtype.name, value)) + return value + unwrapped_type = type(value) conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None) if conversion_func_list is None: