Scope the scalar cache in the context.
PiperOrigin-RevId: 168065417
This commit is contained in:
parent
48deb206ba
commit
0753b0c790
@ -53,6 +53,7 @@ class _EagerContext(threading.local):
|
||||
self.mode = _default_mode
|
||||
self.scope_name = ""
|
||||
self.recording_summaries = False
|
||||
self.scalar_cache = {}
|
||||
|
||||
|
||||
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
|
||||
@ -157,6 +158,10 @@ class Context(object):
|
||||
"""Returns True if current thread is in EAGER mode."""
|
||||
return self._eager_context.mode == EAGER_MODE
|
||||
|
||||
def scalar_cache(self):
|
||||
"""Per-device cache for scalars."""
|
||||
return self._eager_context.scalar_cache
|
||||
|
||||
@property
|
||||
def scope_name(self):
|
||||
"""Returns scope name for the current thread."""
|
||||
|
@ -74,10 +74,6 @@ def _eager_fill(dims, value):
|
||||
return result
|
||||
|
||||
|
||||
# Rely on the GIL for thread-safety.
|
||||
_scalar_cache = {}
|
||||
|
||||
|
||||
def convert_to_eager_tensor(t, dtype=None):
|
||||
"""Converts the given `value` to an `EagerTensor`."""
|
||||
if isinstance(ag_core.getval(t), ops.EagerTensor):
|
||||
@ -88,13 +84,15 @@ def convert_to_eager_tensor(t, dtype=None):
|
||||
# Use a scalar cache. This will put each scalar of each type only once on
|
||||
# each device. Scalars don't use much device memory but copying scalars can
|
||||
# trigger memcpys which are slow.
|
||||
device = context.context().device_name
|
||||
ctx = context.context()
|
||||
device = ctx.device_name
|
||||
cache_key = device, t, dtype, type(t)
|
||||
tensor = _scalar_cache.get(cache_key, None)
|
||||
scalar_cache = ctx.scalar_cache()
|
||||
tensor = scalar_cache.get(cache_key, None)
|
||||
if tensor is not None:
|
||||
return tensor
|
||||
value = ops.EagerTensor(t, dtype=dtype)
|
||||
_scalar_cache[cache_key] = value
|
||||
scalar_cache[cache_key] = value
|
||||
return value
|
||||
return ops.EagerTensor(t, dtype=dtype)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user