diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index a5a93b7bbe0..c6af7df176f 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -35,6 +35,12 @@ EAGER_MODE = 1 # Default execution mode. _default_mode = GRAPH_MODE +# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, +# new_device_spec). +# Note that we do not protect this with a lock and instead rely on python's GIL +# and the idempotent nature of writes to provide thread safety. +_device_parsing_cache = {} + # TODO(agarwal): better name ? class _EagerContext(threading.local): @@ -197,22 +203,33 @@ class Context(object): eager_context = self._eager_context old_device_name = eager_context.device_name old_device_spec = eager_context.device_spec - if name is not None: - if not isinstance(name, str): - raise ValueError("Expecting a string device name. Got %s(%s)" % - (type(name), name)) - device_spec = pydev.DeviceSpec.from_string(name) - if old_device_name: - new_device_spec = copy.copy(old_device_spec) + cache_key = (old_device_name, name) + try: + new_device_name, new_device_spec = _device_parsing_cache[cache_key] + except TypeError: + # Error while trying to compute the cache key. + raise ValueError("Expecting a string device name. Got %s(%s)" % + (type(name), name)) + except KeyError: + # Handle a cache miss. + if name is not None: + if not isinstance(name, str): + raise ValueError("Expecting a string device name. Got %s(%s)" % + (type(name), name)) + device_spec = pydev.DeviceSpec.from_string(name) + if old_device_name: + new_device_spec = copy.copy(old_device_spec) + else: + new_device_spec = pydev.DeviceSpec.from_string( + "/job:localhost/replica:0/task:0/device:CPU:0") + new_device_spec.merge_from(device_spec) else: - new_device_spec = pydev.DeviceSpec.from_string( - "/job:localhost/replica:0/task:0/device:CPU:0") - new_device_spec.merge_from(device_spec) - else: - new_device_spec = pydev.DeviceSpec.from_string("") + new_device_spec = pydev.DeviceSpec.from_string("") + new_device_name = new_device_spec.to_string() + _device_parsing_cache[cache_key] = (new_device_name, new_device_spec) try: - eager_context.device_name = new_device_spec.to_string() + eager_context.device_name = new_device_name eager_context.device_spec = new_device_spec yield finally: