Extracted conversion function caching into tensor_conversion_registry

This change also fixes a race condition in the cache population logic
which could result in cache being populated more than once for a particular
type.

PiperOrigin-RevId: 255590346
This commit is contained in:
Sergei Lebedev 2019-06-28 06:27:50 -07:00 committed by TensorFlower Gardener
parent de9ba3670d
commit 3438a4c411
2 changed files with 27 additions and 19 deletions

View File

@ -81,12 +81,6 @@ _api_usage_gauge = monitoring.BoolGauge(
# pylint: disable=protected-access
_TensorLike = tensor_like._TensorLike
_tensor_conversion_func_registry = \
tensor_conversion_registry._tensor_conversion_func_registry
_tensor_conversion_func_cache = \
tensor_conversion_registry._tensor_conversion_func_cache
_tensor_conversion_func_lock = \
tensor_conversion_registry._tensor_conversion_func_lock
# pylint: enable=protected-access
@ -1180,19 +1174,7 @@ def internal_convert_to_tensor(value,
(dtype.name, value.dtype.name, value))
return value
unwrapped_type = type(value)
conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
if conversion_func_list is None:
with _tensor_conversion_func_lock:
conversion_func_list = []
for _, funcs_at_priority in sorted(
_tensor_conversion_func_registry.items()):
for base_type, conversion_func in funcs_at_priority:
if isinstance(value, base_type):
conversion_func_list.append((base_type, conversion_func))
_tensor_conversion_func_cache[unwrapped_type] = conversion_func_list
for base_type, conversion_func in conversion_func_list:
for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
# If dtype is None but preferred_dtype is not None, we try to
# cast to preferred_dtype first.
ret = None

View File

@ -108,3 +108,29 @@ def register_tensor_conversion_function(base_type,
_tensor_conversion_func_registry[priority] = funcs_at_priority
funcs_at_priority.append((base_type, conversion_func))
_tensor_conversion_func_cache = {}
def get(query):
"""Get conversion function for objects of `cls`.
Args:
query: The type to query for.
Returns:
A list of conversion functions in increasing order of priority.
"""
conversion_funcs = _tensor_conversion_func_cache.get(query)
if conversion_funcs is None:
with _tensor_conversion_func_lock:
# Has another thread populated the cache in the meantime?
conversion_funcs = _tensor_conversion_func_cache.get(query)
if conversion_funcs is None:
conversion_funcs = []
for _, funcs_at_priority in sorted(
_tensor_conversion_func_registry.items()):
conversion_funcs.extend(
(base_type, conversion_func)
for base_type, conversion_func in funcs_at_priority
if issubclass(query, base_type))
_tensor_conversion_func_cache[query] = conversion_funcs
return conversion_funcs