Optimize tf.device in Eager mode by caching to avoid repeated device spec
merging. PiperOrigin-RevId: 167912007
This commit is contained in:
parent
2494aa452b
commit
90dad32968
@ -35,6 +35,12 @@ EAGER_MODE = 1
|
|||||||
# Default execution mode.
|
# Default execution mode.
|
||||||
_default_mode = GRAPH_MODE
|
_default_mode = GRAPH_MODE
|
||||||
|
|
||||||
|
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
|
||||||
|
# new_device_spec).
|
||||||
|
# Note that we do not protect this with a lock and instead rely on python's GIL
|
||||||
|
# and the idempotent nature of writes to provide thread safety.
|
||||||
|
_device_parsing_cache = {}
|
||||||
|
|
||||||
|
|
||||||
# TODO(agarwal): better name ?
|
# TODO(agarwal): better name ?
|
||||||
class _EagerContext(threading.local):
|
class _EagerContext(threading.local):
|
||||||
@ -197,22 +203,33 @@ class Context(object):
|
|||||||
eager_context = self._eager_context
|
eager_context = self._eager_context
|
||||||
old_device_name = eager_context.device_name
|
old_device_name = eager_context.device_name
|
||||||
old_device_spec = eager_context.device_spec
|
old_device_spec = eager_context.device_spec
|
||||||
if name is not None:
|
cache_key = (old_device_name, name)
|
||||||
if not isinstance(name, str):
|
try:
|
||||||
raise ValueError("Expecting a string device name. Got %s(%s)" %
|
new_device_name, new_device_spec = _device_parsing_cache[cache_key]
|
||||||
(type(name), name))
|
except TypeError:
|
||||||
device_spec = pydev.DeviceSpec.from_string(name)
|
# Error while trying to compute the cache key.
|
||||||
if old_device_name:
|
raise ValueError("Expecting a string device name. Got %s(%s)" %
|
||||||
new_device_spec = copy.copy(old_device_spec)
|
(type(name), name))
|
||||||
|
except KeyError:
|
||||||
|
# Handle a cache miss.
|
||||||
|
if name is not None:
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise ValueError("Expecting a string device name. Got %s(%s)" %
|
||||||
|
(type(name), name))
|
||||||
|
device_spec = pydev.DeviceSpec.from_string(name)
|
||||||
|
if old_device_name:
|
||||||
|
new_device_spec = copy.copy(old_device_spec)
|
||||||
|
else:
|
||||||
|
new_device_spec = pydev.DeviceSpec.from_string(
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0")
|
||||||
|
new_device_spec.merge_from(device_spec)
|
||||||
else:
|
else:
|
||||||
new_device_spec = pydev.DeviceSpec.from_string(
|
new_device_spec = pydev.DeviceSpec.from_string("")
|
||||||
"/job:localhost/replica:0/task:0/device:CPU:0")
|
new_device_name = new_device_spec.to_string()
|
||||||
new_device_spec.merge_from(device_spec)
|
_device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
|
||||||
else:
|
|
||||||
new_device_spec = pydev.DeviceSpec.from_string("")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
eager_context.device_name = new_device_spec.to_string()
|
eager_context.device_name = new_device_name
|
||||||
eager_context.device_spec = new_device_spec
|
eager_context.device_spec = new_device_spec
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user