Optimize tf.device in Eager mode by caching to avoid repeated device spec
merging. PiperOrigin-RevId: 167912007
This commit is contained in:
parent
2494aa452b
commit
90dad32968
@ -35,6 +35,12 @@ EAGER_MODE = 1
|
||||
# Default execution mode.
|
||||
_default_mode = GRAPH_MODE
|
||||
|
||||
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
|
||||
# new_device_spec).
|
||||
# Note that we do not protect this with a lock and instead rely on python's GIL
|
||||
# and the idempotent nature of writes to provide thread safety.
|
||||
_device_parsing_cache = {}
|
||||
|
||||
|
||||
# TODO(agarwal): better name ?
|
||||
class _EagerContext(threading.local):
|
||||
@ -197,22 +203,33 @@ class Context(object):
|
||||
eager_context = self._eager_context
|
||||
old_device_name = eager_context.device_name
|
||||
old_device_spec = eager_context.device_spec
|
||||
if name is not None:
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("Expecting a string device name. Got %s(%s)" %
|
||||
(type(name), name))
|
||||
device_spec = pydev.DeviceSpec.from_string(name)
|
||||
if old_device_name:
|
||||
new_device_spec = copy.copy(old_device_spec)
|
||||
cache_key = (old_device_name, name)
|
||||
try:
|
||||
new_device_name, new_device_spec = _device_parsing_cache[cache_key]
|
||||
except TypeError:
|
||||
# Error while trying to compute the cache key.
|
||||
raise ValueError("Expecting a string device name. Got %s(%s)" %
|
||||
(type(name), name))
|
||||
except KeyError:
|
||||
# Handle a cache miss.
|
||||
if name is not None:
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("Expecting a string device name. Got %s(%s)" %
|
||||
(type(name), name))
|
||||
device_spec = pydev.DeviceSpec.from_string(name)
|
||||
if old_device_name:
|
||||
new_device_spec = copy.copy(old_device_spec)
|
||||
else:
|
||||
new_device_spec = pydev.DeviceSpec.from_string(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0")
|
||||
new_device_spec.merge_from(device_spec)
|
||||
else:
|
||||
new_device_spec = pydev.DeviceSpec.from_string(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0")
|
||||
new_device_spec.merge_from(device_spec)
|
||||
else:
|
||||
new_device_spec = pydev.DeviceSpec.from_string("")
|
||||
new_device_spec = pydev.DeviceSpec.from_string("")
|
||||
new_device_name = new_device_spec.to_string()
|
||||
_device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
|
||||
|
||||
try:
|
||||
eager_context.device_name = new_device_spec.to_string()
|
||||
eager_context.device_name = new_device_name
|
||||
eager_context.device_spec = new_device_spec
|
||||
yield
|
||||
finally:
|
||||
|
Loading…
Reference in New Issue
Block a user