diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 185cd9a7165..8496a02947f 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -53,6 +53,7 @@ class _EagerContext(threading.local):
     self.mode = _default_mode
     self.scope_name = ""
     self.recording_summaries = False
+    self.scalar_cache = {}
 
 
 # TODO(agarwal): rename to EagerContext / EagerRuntime ?
@@ -157,6 +158,10 @@ class Context(object):
     """Returns True if current thread is in EAGER mode."""
     return self._eager_context.mode == EAGER_MODE
 
+  def scalar_cache(self):
+    """Per-device cache for scalars."""
+    return self._eager_context.scalar_cache
+
   @property
   def scope_name(self):
     """Returns scope name for the current thread."""
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index 819730a51b9..a859645950d 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -74,10 +74,6 @@ def _eager_fill(dims, value):
   return result
 
 
-# Rely on the GIL for thread-safety.
-_scalar_cache = {}
-
-
 def convert_to_eager_tensor(t, dtype=None):
   """Converts the given `value` to an `EagerTensor`."""
   if isinstance(ag_core.getval(t), ops.EagerTensor):
@@ -88,13 +84,15 @@ def convert_to_eager_tensor(t, dtype=None):
     # Use a scalar cache. This will put each scalar of each type only once on
     # each device. Scalars don't use much device memory but copying scalars can
     # trigger memcpys which are slow.
-    device = context.context().device_name
+    ctx = context.context()
+    device = ctx.device_name
     cache_key = device, t, dtype, type(t)
-    tensor = _scalar_cache.get(cache_key, None)
+    scalar_cache = ctx.scalar_cache()
+    tensor = scalar_cache.get(cache_key, None)
     if tensor is not None:
       return tensor
     value = ops.EagerTensor(t, dtype=dtype)
-    _scalar_cache[cache_key] = value
+    scalar_cache[cache_key] = value
     return value
   return ops.EagerTensor(t, dtype=dtype)
 
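
For readers skimming the diff, here is a minimal standalone sketch (plain Python, no TensorFlow imports) of the pattern being applied: the module-level `_scalar_cache` dict is replaced by a cache that lives on the thread-local `_EagerContext`, keyed by `(device, value, dtype, type(value))` as in `convert_to_eager_tensor`. The names `_FakeEagerContext` and `cached_scalar` are hypothetical stand-ins for illustration, not part of the change.

```python
import threading


class _FakeEagerContext(threading.local):
  """Stand-in for _EagerContext: per-thread state, including the scalar cache."""

  def __init__(self):
    super(_FakeEagerContext, self).__init__()
    # threading.local re-runs __init__ in each new thread, so every thread
    # starts with its own empty cache, mirroring `self.scalar_cache = {}` above.
    self.scalar_cache = {}


_ctx = _FakeEagerContext()


def cached_scalar(value, dtype=None, device="CPU:0"):
  """Mirrors the lookup in convert_to_eager_tensor: one entry per scalar per device."""
  cache_key = device, value, dtype, type(value)
  cached = _ctx.scalar_cache.get(cache_key)
  if cached is None:
    cached = ("EagerTensor", value, dtype)  # placeholder for ops.EagerTensor(...)
    _ctx.scalar_cache[cache_key] = cached
  return cached


# Same thread, same scalar: the cached object is returned, nothing new is built.
assert cached_scalar(1.0) is cached_scalar(1.0)


def _other_thread():
  # A different thread sees a fresh, empty cache via threading.local.
  assert len(_ctx.scalar_cache) == 0


t = threading.Thread(target=_other_thread)
t.start()
t.join()
```

Compared with the removed module-level dict (which relied on the GIL and was shared by all threads), the cache now lives on the thread-local eager context, so entries are per-thread and go away with the context that created them.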