jacobian: allow case where output is not dependent on input.
PiperOrigin-RevId: 209645191
This commit is contained in:
parent
fce0a4eaab
commit
dc62ab7a7c
@ -46,6 +46,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
|
||||
"""
|
||||
|
||||
flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
|
||||
is_none_list = []
|
||||
|
||||
def while_body(i, *ta_list):
|
||||
"""Body of while loop."""
|
||||
@ -56,10 +57,13 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
|
||||
"actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
|
||||
len(fn_output)))
|
||||
outputs = []
|
||||
del is_none_list[:]
|
||||
is_none_list.extend([x is None for x in fn_output])
|
||||
for out, ta in zip(fn_output, ta_list):
|
||||
# TODO(agarwal): support returning Operation objects from loop_fn.
|
||||
assert isinstance(out, ops.Tensor)
|
||||
outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
|
||||
if out is not None:
|
||||
ta = ta.write(i, array_ops.expand_dims(out, 0))
|
||||
outputs.append(ta)
|
||||
return tuple([i + 1] + outputs)
|
||||
|
||||
ta_list = control_flow_ops.while_loop(
|
||||
@ -69,7 +73,10 @@ def for_loop(loop_fn, loop_fn_dtypes, iters):
|
||||
])[1:]
|
||||
|
||||
# TODO(rachelim): enable this for sparse tensors
|
||||
return nest.pack_sequence_as(loop_fn_dtypes, [ta.concat() for ta in ta_list])
|
||||
|
||||
output = [None if is_none else ta.concat()
|
||||
for ta, is_none in zip(ta_list, is_none_list)]
|
||||
return nest.pack_sequence_as(loop_fn_dtypes, output)
|
||||
|
||||
|
||||
def pfor(loop_fn, iters):
|
||||
|
@ -61,6 +61,7 @@ def jacobian(output, inputs, use_pfor=True):
|
||||
loop_fn, [output.dtype] * len(flat_inputs), output_size)
|
||||
|
||||
for i, out in enumerate(pfor_outputs):
|
||||
if out is not None:
|
||||
new_shape = array_ops.concat(
|
||||
[output_shape, array_ops.shape(out)[1:]], axis=0)
|
||||
out = array_ops.reshape(out, new_shape)
|
||||
@ -119,6 +120,8 @@ def batch_jacobian(output, inp, use_pfor=True):
|
||||
else:
|
||||
pfor_output = control_flow_ops.for_loop(loop_fn, output.dtype,
|
||||
output_row_size)
|
||||
if pfor_output is None:
|
||||
return None
|
||||
pfor_output = array_ops.reshape(pfor_output,
|
||||
[output_row_size, batch_size, -1])
|
||||
output = array_ops.transpose(pfor_output, [1, 0, 2])
|
||||
|
@ -333,6 +333,13 @@ class GradientsTest(test.TestCase):
|
||||
for i in range(n):
|
||||
self.assertAllClose(outputs[i], outputs[i + n], rtol=rtol, atol=atol)
|
||||
|
||||
def test_no_path(self):
|
||||
for grad_func in [gradients.jacobian, gradients.batch_jacobian]:
|
||||
for use_pfor in [True, False]:
|
||||
x = constant_op.constant([[1.0]])
|
||||
y = constant_op.constant([[2.0]])
|
||||
self.assertIsNone(grad_func(y, x, use_pfor=use_pfor))
|
||||
|
||||
def test_jacobian_fixed_shape(self):
|
||||
x = random_ops.random_uniform([2, 2])
|
||||
y = math_ops.matmul(x, x, transpose_a=True)
|
||||
|
@ -1070,6 +1070,8 @@ class PFor(object):
|
||||
If y does not need to be converted, it returns y as is. Else it returns
|
||||
the "converted value" corresponding to y.
|
||||
"""
|
||||
if y is None:
|
||||
return None
|
||||
if isinstance(y, sparse_tensor.SparseTensor):
|
||||
return self._convert_sparse(y)
|
||||
output = self._convert_helper(y)
|
||||
|
Loading…
Reference in New Issue
Block a user