Improve performance of tf.subscribe on large graphs by caching control outputs.
Change: 138473617
This commit is contained in:
parent
10d6df2f95
commit
e1dddfb401
tensorflow/python/framework
@ -57,25 +57,51 @@ def _recursive_apply(tensors, apply_fn):
|
||||
(tensors, tensors_type))
|
||||
|
||||
|
||||
def _control_outputs(op):
|
||||
"""Returns the control_input consumers for the supplied `Operation`.
|
||||
class _ControlOutputCache(object):
|
||||
"""Helper class to manage calculating and caching control_outputs in graph."""
|
||||
|
||||
Args:
|
||||
op: The `Operation` to find consumers of.
|
||||
Yields:
|
||||
A list of ops that have op as a control dependency.
|
||||
"""
|
||||
for o in op.graph.get_operations():
|
||||
if op in o.control_inputs:
|
||||
yield o
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def calc_control_outputs(self, graph):
|
||||
"""Returns the map of control_outputs for a given graph.
|
||||
|
||||
Args:
|
||||
graph: The graph to parse.
|
||||
Returns:
|
||||
A map of the control outputs.
|
||||
"""
|
||||
control_outputs = {}
|
||||
for op in graph.get_operations():
|
||||
for control_input in op.control_inputs:
|
||||
if control_input not in control_outputs:
|
||||
control_outputs[control_input] = set()
|
||||
control_outputs[control_input].add(op)
|
||||
return control_outputs
|
||||
|
||||
def get_control_outputs(self, op):
|
||||
"""Return the control outputs for a given op.
|
||||
|
||||
Args:
|
||||
op: The op to fetch control outputs for.
|
||||
Returns:
|
||||
Iterable of control output ops.
|
||||
"""
|
||||
if op.graph not in self.cache:
|
||||
control_outputs = self.calc_control_outputs(op.graph)
|
||||
self.cache[op.graph] = control_outputs
|
||||
else:
|
||||
control_outputs = self.cache[op.graph]
|
||||
return control_outputs.get(op, [])
|
||||
|
||||
|
||||
def _subscribe(tensor, side_effects):
|
||||
def _subscribe(tensor, side_effects, control_cache):
|
||||
"""Helper method that subscribes a single tensor to a list of side_effects.
|
||||
|
||||
Args:
|
||||
tensor: `tf.Tensor`
|
||||
side_effects: List of side_effect functions see subscribe for details.
|
||||
control_cache: `_ControlOutputCache` helper to get control_outputs faster.
|
||||
Returns:
|
||||
The modified replacement to the passed in tensor which triggers the side
|
||||
effects.
|
||||
@ -84,7 +110,7 @@ def _subscribe(tensor, side_effects):
|
||||
for consumer_op in list(tensor.consumers()): # explicit copy
|
||||
update_input.append((consumer_op, list(consumer_op.inputs).index(tensor)))
|
||||
|
||||
update_control_input = list(_control_outputs(tensor.op))
|
||||
update_control_input = control_cache.get_control_outputs(tensor.op)
|
||||
|
||||
# Trailing slash on name scope to replace the scope.
|
||||
name_scope = tensor.op.name + '/subscription/'
|
||||
@ -141,4 +167,8 @@ def subscribe(tensors, side_effects):
|
||||
"""
|
||||
if not hasattr(side_effects, '__iter__'):
|
||||
side_effects = [side_effects]
|
||||
return _recursive_apply(tensors, lambda t: _subscribe(t, side_effects))
|
||||
|
||||
control_outputs = _ControlOutputCache()
|
||||
result = _recursive_apply(
|
||||
tensors, lambda t: _subscribe(t, side_effects, control_outputs))
|
||||
return result
|
||||
|
@ -54,6 +54,36 @@ class SubscribeTest(test_util.TensorFlowTestCase):
|
||||
self.assertEquals(d_out, [42])
|
||||
self.assertEquals(shared, [2, 2, 2])
|
||||
|
||||
def testCaching(self):
|
||||
"""Confirm caching of control output is recacluated between calls."""
|
||||
a = tf.constant(1)
|
||||
b = tf.constant(2)
|
||||
with tf.control_dependencies([a]):
|
||||
c = tf.constant(42)
|
||||
|
||||
shared = {}
|
||||
|
||||
def sub(t):
|
||||
shared[t] = shared.get(t, 0) + 1
|
||||
return t
|
||||
|
||||
a = subscribe.subscribe(a, lambda t: tf.py_func(sub, [t], [t.dtype]))
|
||||
|
||||
with tf.control_dependencies([b]):
|
||||
d = tf.constant(11)
|
||||
|
||||
# If it was using outdated cached control_outputs then
|
||||
# evaling would not trigger the new subscription.
|
||||
b = subscribe.subscribe(b, lambda t: tf.py_func(sub, [t], [t.dtype]))
|
||||
|
||||
with self.test_session() as sess:
|
||||
c_out = sess.run([c])
|
||||
d_out = sess.run([d])
|
||||
|
||||
self.assertEquals(c_out, [42])
|
||||
self.assertEquals(d_out, [11])
|
||||
self.assertEquals(shared, {2: 1, 1: 1})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user