Add parallel_iterations option to batch_jacobian and jacobian functions.

PiperOrigin-RevId: 223241273
A. Unique TensorFlower, 2018-11-28 14:56:02 -08:00 (committed by TensorFlower Gardener)
parent bbb81ea428
commit a7b3f17a16
4 changed files with 35 additions and 8 deletions
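For reference, a minimal usage sketch of the new knob (not part of this commit; assumes eager execution and a TensorFlow build that includes this change, with illustrative shapes and values):

    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x only; eager is always on in TF 2.x

    x = tf.constant([[1., 2.], [3., 4.]])
    with tf.GradientTape(persistent=True) as g:
      g.watch(x)
      y = tf.matmul(x, x)

    # parallel_iterations bounds how many loop iterations are dispatched at
    # once, trading peak memory for wall time; the jacobian itself is the
    # same for any setting. Result shape here is [2, 2, 2, 2].
    j = g.jacobian(y, x, parallel_iterations=2)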


@@ -955,6 +955,7 @@ class GradientTape(object):
                target,
                sources,
                unconnected_gradients=UnconnectedGradients.NONE,
+               parallel_iterations=None,
                experimental_use_pfor=True):
     """Computes the jacobian using operations recorded in context of this tape.
@@ -978,6 +979,8 @@ class GradientTape(object):
         alters the value which will be returned if the target and sources are
         unconnected. The possible values and effects are detailed in
         'UnconnectedGradients' and it defaults to 'none'.
+      parallel_iterations: A knob to control how many iterations are dispatched
+        in parallel. This knob can be used to control the total memory usage.
       experimental_use_pfor: If true, vectorizes the jacobian computation. Else
         falls back to a sequential while_loop. Vectorization can sometimes fail
         or lead to excessive memory usage. This option can be used to disable
@@ -1016,7 +1019,8 @@ class GradientTape(object):
     if experimental_use_pfor:
       try:
-        output = pfor_ops.pfor(loop_fn, target_size)
+        output = pfor_ops.pfor(loop_fn, target_size,
+                               parallel_iterations=parallel_iterations)
       except ValueError as err:
         six.reraise(
             ValueError,
@@ -1032,7 +1036,8 @@ class GradientTape(object):
           " to compute the jacobian with eager execution enabled and with "
           " experimental_use_pfor set to False.")
       output = pfor_ops.for_loop(
-          loop_fn, [target.dtype] * len(flat_sources), target_size)
+          loop_fn, [target.dtype] * len(flat_sources), target_size,
+          parallel_iterations=parallel_iterations)
     for i, out in enumerate(output):
       if out is not None:
@@ -1049,6 +1054,7 @@ class GradientTape(object):
                      target,
                      source,
                      unconnected_gradients=UnconnectedGradients.NONE,
+                     parallel_iterations=None,
                      experimental_use_pfor=True):
     """Computes and stacks per-example jacobians.
@@ -1081,6 +1087,8 @@ class GradientTape(object):
         alters the value which will be returned if the target and sources are
         unconnected. The possible values and effects are detailed in
         'UnconnectedGradients' and it defaults to 'none'.
+      parallel_iterations: A knob to control how many iterations are dispatched
+        in parallel. This knob can be used to control the total memory usage.
       experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else
         uses a tf.while_loop.
@@ -1127,7 +1135,8 @@ class GradientTape(object):
     if experimental_use_pfor:
       try:
-        output = pfor_ops.pfor(loop_fn, target_row_size)
+        output = pfor_ops.pfor(loop_fn, target_row_size,
+                               parallel_iterations=parallel_iterations)
       except ValueError as err:
         six.reraise(
             ValueError,
@@ -1142,7 +1151,8 @@ class GradientTape(object):
           "GradientTape must be created with persistent=True"
           " to compute the batch_jacobian with eager execution enabled and "
           " with experimental_use_pfor set to False.")
-      output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size)
+      output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
+                                 parallel_iterations=parallel_iterations)
     if output is None:
       return None
     output = array_ops.reshape(output,
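As the hunks above show, the knob is threaded through both the vectorized pfor path and the sequential while_loop fallback. A sketch of the fallback usage, where parallel_iterations matters most for peak memory (not part of this commit; assumes eager execution, and notes the persistent=True requirement from the error message above):

    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x only; eager is always on in TF 2.x

    x = tf.constant([[1., 2.], [3., 4.]])                  # batch of 2 examples
    w = tf.constant([[1., 2., 3., 4.], [5., 6., 7., 8.]])
    with tf.GradientTape(persistent=True) as g:            # persistent: required here
      g.watch(x)
      y = tf.matmul(x, w)                                  # per-example output, shape [2, 4]

    # With experimental_use_pfor=False the computation runs as a while_loop;
    # parallel_iterations caps how many gradient slices are in flight at once.
    # batch_jacobian stacks per-example jacobians: shape [2, 4, 2].
    jb = g.batch_jacobian(y, x, parallel_iterations=2,
                          experimental_use_pfor=False)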


@@ -1303,6 +1303,14 @@ class JacobianTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'No converter'):
       g.jacobian(y, x, experimental_use_pfor=True)
 
+  def test_parallel_iterations(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant([[1., 2], [3, 4]])
+      g.watch(x)
+      y = math_ops.matmul(x, x)
+    self.assertAllClose(g.jacobian(y, x, parallel_iterations=2),
+                        g.jacobian(y, x, parallel_iterations=3))
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class BatchJacobianTest(test.TestCase):
@@ -1397,5 +1405,14 @@ class BatchJacobianTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, 'No converter'):
       g.batch_jacobian(y, x, experimental_use_pfor=True)
 
+  def test_parallel_iterations(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant([[1., 2], [3, 4]])
+      g.watch(x)
+      w = constant_op.constant([[1., 2, 3, 4], [5, 6, 7, 8]])
+      y = math_ops.matmul(x, w)
+    self.assertAllClose(g.batch_jacobian(y, x, parallel_iterations=2),
+                        g.batch_jacobian(y, x, parallel_iterations=3))
+
 if __name__ == '__main__':
   test.main()
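The new tests only compare different parallel_iterations values against each other. An equivalent standalone spot-check, also comparing the pfor path against the while_loop fallback (a sketch, not part of this commit; assumes eager execution):

    import numpy as np
    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x only; eager is always on in TF 2.x

    x = tf.constant([[1., 2.], [3., 4.]])
    with tf.GradientTape(persistent=True) as g:
      g.watch(x)
      y = tf.matmul(x, x)

    # Both execution paths should yield the same jacobian for any
    # parallel_iterations setting.
    j_pfor = g.jacobian(y, x, parallel_iterations=2)
    j_loop = g.jacobian(y, x, parallel_iterations=3,
                        experimental_use_pfor=False)
    np.testing.assert_allclose(j_pfor.numpy(), j_loop.numpy())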


@@ -8,7 +8,7 @@ tf_class {
   }
   member_method {
     name: "batch_jacobian"
-    argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+    argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
   }
   member_method {
     name: "gradient"
@@ -16,7 +16,7 @@ tf_class {
   }
   member_method {
     name: "jacobian"
-    argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+    argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
   }
   member_method {
     name: "reset"


@@ -8,7 +8,7 @@ tf_class {
   }
   member_method {
     name: "batch_jacobian"
-    argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+    argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
  }
   member_method {
     name: "gradient"
@@ -16,7 +16,7 @@ tf_class {
   }
   member_method {
     name: "jacobian"
-    argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+    argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
   }
   member_method {
     name: "reset"