Generalized *Tensor fast-path in internal_convert_to_tensor
Note that Tensor no longer has a conversion function because it never reaches the slow path iterating over the registered converters. PiperOrigin-RevId: 255584682
This commit is contained in:
parent
d97ccf68c9
commit
de9ba3670d
@ -1030,15 +1030,6 @@ class _EagerTensorBase(Tensor):
|
|||||||
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
|
||||||
|
|
||||||
|
|
||||||
def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
|
|
||||||
_ = name, as_ref
|
|
||||||
if dtype and not dtype.is_compatible_with(t.dtype):
|
|
||||||
raise ValueError(
|
|
||||||
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
|
|
||||||
(dtype.name, t.dtype.name, str(t)))
|
|
||||||
return t
|
|
||||||
tensor_conversion_registry.register_tensor_conversion_function(
|
|
||||||
Tensor, _TensorTensorConversionFunction, 0)
|
|
||||||
register_dense_tensor_like_type(Tensor)
|
register_dense_tensor_like_type(Tensor)
|
||||||
|
|
||||||
|
|
||||||
@ -1172,23 +1163,23 @@ def internal_convert_to_tensor(value,
|
|||||||
ctx=None,
|
ctx=None,
|
||||||
accept_composite_tensors=False):
|
accept_composite_tensors=False):
|
||||||
"""Implementation of the public convert_to_tensor."""
|
"""Implementation of the public convert_to_tensor."""
|
||||||
if ctx is None:
|
|
||||||
ctx = context.context()
|
|
||||||
if isinstance(value, EagerTensor):
|
|
||||||
if ctx.executing_eagerly():
|
|
||||||
if dtype is not None:
|
|
||||||
dtype = dtypes.as_dtype(dtype)
|
|
||||||
value = _TensorTensorConversionFunction(value, dtype=dtype)
|
|
||||||
return value
|
|
||||||
else:
|
|
||||||
graph = get_default_graph()
|
|
||||||
if not graph.building_function:
|
|
||||||
raise RuntimeError("Attempting to capture an EagerTensor without "
|
|
||||||
"building a function.")
|
|
||||||
return graph.capture(value, name=name)
|
|
||||||
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
dtype = dtypes.as_dtype(dtype)
|
dtype = dtypes.as_dtype(dtype)
|
||||||
|
if ctx is None:
|
||||||
|
ctx = context.context()
|
||||||
|
if isinstance(value, EagerTensor) and not ctx.executing_eagerly():
|
||||||
|
graph = get_default_graph()
|
||||||
|
if not graph.building_function:
|
||||||
|
raise RuntimeError("Attempting to capture an EagerTensor without "
|
||||||
|
"building a function.")
|
||||||
|
return graph.capture(value, name=name)
|
||||||
|
elif isinstance(value, Tensor):
|
||||||
|
if dtype is not None and not dtype.is_compatible_with(value.dtype):
|
||||||
|
raise ValueError(
|
||||||
|
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
|
||||||
|
(dtype.name, value.dtype.name, value))
|
||||||
|
return value
|
||||||
|
|
||||||
unwrapped_type = type(value)
|
unwrapped_type = type(value)
|
||||||
conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
|
conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
|
||||||
if conversion_func_list is None:
|
if conversion_func_list is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user