Scope the scalar cache in the context.

PiperOrigin-RevId: 168065417
This commit is contained in:
Alexandre Passos 2017-09-08 16:52:18 -07:00 committed by TensorFlower Gardener
parent 48deb206ba
commit 0753b0c790
2 changed files with 10 additions and 7 deletions

View File

@ -53,6 +53,7 @@ class _EagerContext(threading.local):
self.mode = _default_mode
self.scope_name = ""
self.recording_summaries = False
self.scalar_cache = {}
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
@ -157,6 +158,10 @@ class Context(object):
"""Returns True if current thread is in EAGER mode."""
return self._eager_context.mode == EAGER_MODE
def scalar_cache(self):
"""Per-device cache for scalars."""
return self._eager_context.scalar_cache
@property
def scope_name(self):
"""Returns scope name for the current thread."""

View File

@ -74,10 +74,6 @@ def _eager_fill(dims, value):
return result
# Rely on the GIL for thread-safety.
_scalar_cache = {}
def convert_to_eager_tensor(t, dtype=None):
"""Converts the given `value` to an `EagerTensor`."""
if isinstance(ag_core.getval(t), ops.EagerTensor):
@ -88,13 +84,15 @@ def convert_to_eager_tensor(t, dtype=None):
# Use a scalar cache. This will put each scalar of each type only once on
# each device. Scalars don't use much device memory but copying scalars can
# trigger memcpys which are slow.
device = context.context().device_name
ctx = context.context()
device = ctx.device_name
cache_key = device, t, dtype, type(t)
tensor = _scalar_cache.get(cache_key, None)
scalar_cache = ctx.scalar_cache()
tensor = scalar_cache.get(cache_key, None)
if tensor is not None:
return tensor
value = ops.EagerTensor(t, dtype=dtype)
_scalar_cache[cache_key] = value
scalar_cache[cache_key] = value
return value
return ops.EagerTensor(t, dtype=dtype)