jacobian: allow the case where the output does not depend on the input.
PiperOrigin-RevId: 209645191
parent fce0a4eaab
commit dc62ab7a7c
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -46,6 +46,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
   """
 
   flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
+  is_none_list = []
 
   def while_body(i, *ta_list):
     """Body of while loop."""
@@ -56,10 +57,13 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
           "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
                                                 len(fn_output)))
     outputs = []
+    del is_none_list[:]
+    is_none_list.extend([x is None for x in fn_output])
     for out, ta in zip(fn_output, ta_list):
       # TODO(agarwal): support returning Operation objects from loop_fn.
-      assert isinstance(out, ops.Tensor)
-      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
+      if out is not None:
+        ta = ta.write(i, array_ops.expand_dims(out, 0))
+      outputs.append(ta)
     return tuple([i + 1] + outputs)
 
   ta_list = control_flow_ops.while_loop(
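A note on the `del is_none_list[:]` / `extend` pair: the list is mutated in place rather than rebound, so the enclosing `for_loop` scope keeps seeing the current contents; clearing it first presumably guards against `while_body` being invoked more than once during graph construction.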
@@ -69,7 +73,10 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
       ])[1:]
 
   # TODO(rachelim): enable this for sparse tensors
-  return nest.pack_sequence_as(loop_fn_dtypes, [ta.concat() for ta in ta_list])
+
+  output = [None if is_none else ta.concat()
+            for ta, is_none in zip(ta_list, is_none_list)]
+  return nest.pack_sequence_as(loop_fn_dtypes, output)
 
 
 def pfor(loop_fn, iters):
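With the three hunks above, `for_loop` now tolerates a `loop_fn` whose flattened output contains `None`: the corresponding `TensorArray` is never written, and `None` is packed back into the result in its place. A minimal sketch of the new behavior; `for_loop` is an internal helper, not public API, and the import paths below are assumptions based on the identifiers in these hunks:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.parallel_for import control_flow_ops

def loop_fn(i):
  # The second output is None, e.g. a gradient with no path to its input.
  return math_ops.cast(i, dtypes.float32), None

# Before this change, while_body asserted that every output was a Tensor
# and would have failed here; now the None slot simply stays unwritten.
stacked, none_out = control_flow_ops.for_loop(
    loop_fn, [dtypes.float32, dtypes.float32], 3)
# stacked is a float32 tensor of shape [3]; none_out is None.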
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -61,9 +61,10 @@ def jacobian(output, inputs, use_pfor=True):
         loop_fn, [output.dtype] * len(flat_inputs), output_size)
 
   for i, out in enumerate(pfor_outputs):
-    new_shape = array_ops.concat(
-        [output_shape, array_ops.shape(out)[1:]], axis=0)
-    out = array_ops.reshape(out, new_shape)
+    if out is not None:
+      new_shape = array_ops.concat(
+          [output_shape, array_ops.shape(out)[1:]], axis=0)
+      out = array_ops.reshape(out, new_shape)
     pfor_outputs[i] = out
 
   return nest.pack_sequence_as(inputs, pfor_outputs)
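Skipping the reshape when `out` is `None` lets the `None` flow through to `nest.pack_sequence_as`, so `jacobian` reports `None` for each input that `output` does not depend on, matching the convention of `tf.gradients`.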
@@ -119,6 +120,8 @@ def batch_jacobian(output, inp, use_pfor=True):
   else:
     pfor_output = control_flow_ops.for_loop(loop_fn, output.dtype,
                                             output_row_size)
+  if pfor_output is None:
+    return None
   pfor_output = array_ops.reshape(pfor_output,
                                   [output_row_size, batch_size, -1])
   output = array_ops.transpose(pfor_output, [1, 0, 2])
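Together, `jacobian` and `batch_jacobian` now return `None` when there is no differentiable path from the output to the input, instead of failing inside the reshape. A usage sketch mirroring the new test below (graph mode; these are internal `parallel_for` modules, not public API):

from tensorflow.python.framework import constant_op
from tensorflow.python.ops.parallel_for import gradients

x = constant_op.constant([[1.0]])
y = constant_op.constant([[2.0]])  # y does not depend on x

# Both functions return None for the unconnected input, with and without pfor.
assert gradients.jacobian(y, x, use_pfor=True) is None
assert gradients.batch_jacobian(y, x, use_pfor=False) is None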
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -333,6 +333,13 @@ class GradientsTest(test.TestCase):
     for i in range(n):
       self.assertAllClose(outputs[i], outputs[i + n], rtol=rtol, atol=atol)
 
+  def test_no_path(self):
+    for grad_func in [gradients.jacobian, gradients.batch_jacobian]:
+      for use_pfor in [True, False]:
+        x = constant_op.constant([[1.0]])
+        y = constant_op.constant([[2.0]])
+        self.assertIsNone(grad_func(y, x, use_pfor=use_pfor))
+
   def test_jacobian_fixed_shape(self):
     x = random_ops.random_uniform([2, 2])
     y = math_ops.matmul(x, x, transpose_a=True)
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -1070,6 +1070,8 @@ class PFor(object):
     If y does not need to be converted, it returns y as is. Else it returns
     the "converted value" corresponding to y.
     """
+    if y is None:
+      return None
     if isinstance(y, sparse_tensor.SparseTensor):
       return self._convert_sparse(y)
     output = self._convert_helper(y)
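`PFor._convert` runs on every value a `pfor` body returns, so short-circuiting on `None` lets loop bodies emit `None` outputs (for example, gradients of unconnected tensors) without the vectorizing converter ever touching them; the `jacobian` changes above rely on this when `use_pfor=True`.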