Optimize tf.device in Eager mode by caching to avoid repeated device spec

merging.

PiperOrigin-RevId: 167912007
This commit is contained in:
A. Unique TensorFlower 2017-09-07 14:28:50 -07:00 committed by TensorFlower Gardener
parent 2494aa452b
commit 90dad32968

View File

@ -35,6 +35,12 @@ EAGER_MODE = 1
# Default execution mode.
_default_mode = GRAPH_MODE
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}
# TODO(agarwal): better name ?
class _EagerContext(threading.local):
@ -197,22 +203,33 @@ class Context(object):
eager_context = self._eager_context
old_device_name = eager_context.device_name
old_device_spec = eager_context.device_spec
if name is not None:
if not isinstance(name, str):
raise ValueError("Expecting a string device name. Got %s(%s)" %
(type(name), name))
device_spec = pydev.DeviceSpec.from_string(name)
if old_device_name:
new_device_spec = copy.copy(old_device_spec)
cache_key = (old_device_name, name)
try:
new_device_name, new_device_spec = _device_parsing_cache[cache_key]
except TypeError:
# Error while trying to compute the cache key.
raise ValueError("Expecting a string device name. Got %s(%s)" %
(type(name), name))
except KeyError:
# Handle a cache miss.
if name is not None:
if not isinstance(name, str):
raise ValueError("Expecting a string device name. Got %s(%s)" %
(type(name), name))
device_spec = pydev.DeviceSpec.from_string(name)
if old_device_name:
new_device_spec = copy.copy(old_device_spec)
else:
new_device_spec = pydev.DeviceSpec.from_string(
"/job:localhost/replica:0/task:0/device:CPU:0")
new_device_spec.merge_from(device_spec)
else:
new_device_spec = pydev.DeviceSpec.from_string(
"/job:localhost/replica:0/task:0/device:CPU:0")
new_device_spec.merge_from(device_spec)
else:
new_device_spec = pydev.DeviceSpec.from_string("")
new_device_spec = pydev.DeviceSpec.from_string("")
new_device_name = new_device_spec.to_string()
_device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
try:
eager_context.device_name = new_device_spec.to_string()
eager_context.device_name = new_device_name
eager_context.device_spec = new_device_spec
yield
finally: