Extracted conversion function caching into tensor_conversion_registry
This change also fixes a race condition in the cache population logic which could result in the cache being populated more than once for a particular type. PiperOrigin-RevId: 255590346
This commit is contained in:
parent
de9ba3670d
commit
3438a4c411
@ -81,12 +81,6 @@ _api_usage_gauge = monitoring.BoolGauge(
|
|||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
_TensorLike = tensor_like._TensorLike
|
_TensorLike = tensor_like._TensorLike
|
||||||
_tensor_conversion_func_registry = \
|
|
||||||
tensor_conversion_registry._tensor_conversion_func_registry
|
|
||||||
_tensor_conversion_func_cache = \
|
|
||||||
tensor_conversion_registry._tensor_conversion_func_cache
|
|
||||||
_tensor_conversion_func_lock = \
|
|
||||||
tensor_conversion_registry._tensor_conversion_func_lock
|
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
@ -1180,19 +1174,7 @@ def internal_convert_to_tensor(value,
|
|||||||
(dtype.name, value.dtype.name, value))
|
(dtype.name, value.dtype.name, value))
|
||||||
return value
|
return value
|
||||||
|
|
||||||
unwrapped_type = type(value)
|
for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
|
||||||
conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
|
|
||||||
if conversion_func_list is None:
|
|
||||||
with _tensor_conversion_func_lock:
|
|
||||||
conversion_func_list = []
|
|
||||||
for _, funcs_at_priority in sorted(
|
|
||||||
_tensor_conversion_func_registry.items()):
|
|
||||||
for base_type, conversion_func in funcs_at_priority:
|
|
||||||
if isinstance(value, base_type):
|
|
||||||
conversion_func_list.append((base_type, conversion_func))
|
|
||||||
_tensor_conversion_func_cache[unwrapped_type] = conversion_func_list
|
|
||||||
|
|
||||||
for base_type, conversion_func in conversion_func_list:
|
|
||||||
# If dtype is None but preferred_dtype is not None, we try to
|
# If dtype is None but preferred_dtype is not None, we try to
|
||||||
# cast to preferred_dtype first.
|
# cast to preferred_dtype first.
|
||||||
ret = None
|
ret = None
|
||||||
|
@ -108,3 +108,29 @@ def register_tensor_conversion_function(base_type,
|
|||||||
_tensor_conversion_func_registry[priority] = funcs_at_priority
|
_tensor_conversion_func_registry[priority] = funcs_at_priority
|
||||||
funcs_at_priority.append((base_type, conversion_func))
|
funcs_at_priority.append((base_type, conversion_func))
|
||||||
_tensor_conversion_func_cache = {}
|
_tensor_conversion_func_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get(query):
  """Get conversion functions for objects of type `query`.

  Args:
    query: The type to look up registered conversion functions for.

  Returns:
    A list of `(base_type, conversion_func)` pairs applicable to `query`
    (i.e. where `query` is a subclass of `base_type`), in increasing order
    of registration priority.
  """
  # Fast path: read the cache without taking the lock.
  conversion_funcs = _tensor_conversion_func_cache.get(query)
  if conversion_funcs is None:
    with _tensor_conversion_func_lock:
      # Double-checked locking: has another thread populated the cache for
      # this type in the meantime?
      conversion_funcs = _tensor_conversion_func_cache.get(query)
      if conversion_funcs is None:
        conversion_funcs = []
        # The registry maps priority -> list of (base_type, func) pairs;
        # iterate priorities in ascending order so callers see functions
        # in increasing priority.
        for _, funcs_at_priority in sorted(
            _tensor_conversion_func_registry.items()):
          conversion_funcs.extend(
              (base_type, conversion_func)
              for base_type, conversion_func in funcs_at_priority
              if issubclass(query, base_type))
        _tensor_conversion_func_cache[query] = conversion_funcs
  return conversion_funcs
|
||||||
|
Loading…
Reference in New Issue
Block a user