Generalized *Tensor fast-path in internal_convert_to_tensor
Note that Tensor no longer has a conversion function because it never reaches the slow path iterating over the registered converters. PiperOrigin-RevId: 255584682
This commit is contained in:
parent
d97ccf68c9
commit
de9ba3670d
@ -1030,15 +1030,6 @@ class _EagerTensorBase(Tensor):
|
||||
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||
|
||||
|
||||
def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
|
||||
_ = name, as_ref
|
||||
if dtype and not dtype.is_compatible_with(t.dtype):
|
||||
raise ValueError(
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
|
||||
(dtype.name, t.dtype.name, str(t)))
|
||||
return t
|
||||
tensor_conversion_registry.register_tensor_conversion_function(
|
||||
Tensor, _TensorTensorConversionFunction, 0)
|
||||
register_dense_tensor_like_type(Tensor)
|
||||
|
||||
|
||||
@ -1172,23 +1163,23 @@ def internal_convert_to_tensor(value,
|
||||
ctx=None,
|
||||
accept_composite_tensors=False):
|
||||
"""Implementation of the public convert_to_tensor."""
|
||||
if ctx is None:
|
||||
ctx = context.context()
|
||||
if isinstance(value, EagerTensor):
|
||||
if ctx.executing_eagerly():
|
||||
if dtype is not None:
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
value = _TensorTensorConversionFunction(value, dtype=dtype)
|
||||
return value
|
||||
else:
|
||||
if ctx is None:
|
||||
ctx = context.context()
|
||||
if isinstance(value, EagerTensor) and not ctx.executing_eagerly():
|
||||
graph = get_default_graph()
|
||||
if not graph.building_function:
|
||||
raise RuntimeError("Attempting to capture an EagerTensor without "
|
||||
"building a function.")
|
||||
return graph.capture(value, name=name)
|
||||
elif isinstance(value, Tensor):
|
||||
if dtype is not None and not dtype.is_compatible_with(value.dtype):
|
||||
raise ValueError(
|
||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
|
||||
(dtype.name, value.dtype.name, value))
|
||||
return value
|
||||
|
||||
if dtype is not None:
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
unwrapped_type = type(value)
|
||||
conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
|
||||
if conversion_func_list is None:
|
||||
|
Loading…
Reference in New Issue
Block a user