Generalized *Tensor fast-path in internal_convert_to_tensor

Note that Tensor no longer has a conversion function because it never
reaches the slow path iterating over the registered converters.

PiperOrigin-RevId: 255584682
This commit is contained in:
Sergei Lebedev 2019-06-28 05:29:35 -07:00 committed by TensorFlower Gardener
parent d97ccf68c9
commit de9ba3670d

View File

@ -1030,15 +1030,6 @@ class _EagerTensorBase(Tensor):
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
_ = name, as_ref
if dtype and not dtype.is_compatible_with(t.dtype):
raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
(dtype.name, t.dtype.name, str(t)))
return t
tensor_conversion_registry.register_tensor_conversion_function(
Tensor, _TensorTensorConversionFunction, 0)
register_dense_tensor_like_type(Tensor)
@ -1172,23 +1163,23 @@ def internal_convert_to_tensor(value,
ctx=None,
accept_composite_tensors=False):
"""Implementation of the public convert_to_tensor."""
if ctx is None:
ctx = context.context()
if isinstance(value, EagerTensor):
if ctx.executing_eagerly():
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
value = _TensorTensorConversionFunction(value, dtype=dtype)
return value
else:
if ctx is None:
ctx = context.context()
if isinstance(value, EagerTensor) and not ctx.executing_eagerly():
graph = get_default_graph()
if not graph.building_function:
raise RuntimeError("Attempting to capture an EagerTensor without "
"building a function.")
return graph.capture(value, name=name)
elif isinstance(value, Tensor):
if dtype is not None and not dtype.is_compatible_with(value.dtype):
raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
(dtype.name, value.dtype.name, value))
return value
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
unwrapped_type = type(value)
conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
if conversion_func_list is None: