Fix tape recording of attributes of vectorized ops
nest.flatten was mangling some attribute lists. Fixes #40895. PiperOrigin-RevId: 339965951 Change-Id: Ibdbbb5a12b2d06a077b62902764ca4849a2b8d2a
This commit is contained in:
parent
b37b0d00f1
commit
3d1f1b062d
tensorflow/python
@ -1772,6 +1772,31 @@ class JacobianTest(test.TestCase):
|
||||
array_ops.reshape(def_function.function(f)(x), [-1]),
|
||||
rtol=1e-3)
|
||||
|
||||
def test_grad_jacobian_conv(self):
  """Regression test: batch_jacobian through conv2d under tf.function.

  Exercises tape recording of op attributes (strides/padding/data_format)
  for a vectorized (pfor-transformed) conv op, both against a numeric
  gradient check and against an outer GradientTape.
  """

  def _inner(x):
    # Fixed all-ones filter so the expected gradients are deterministic.
    filters = array_ops.ones([3, 3, 1, 9])
    with backprop.GradientTape() as inner_tape:
      inner_tape.watch(x)
      y = nn_ops.conv2d(x, filters, strides=(1, 1), padding='SAME',
                        data_format='NHWC')
      # NOTE(review): assumed the reduction is recorded on the tape
      # (inside the `with`) so batch_jacobian can differentiate it.
      reduced = math_ops.reduce_sum(y ** 2., axis=[2, 3])
    return math_ops.reduce_sum(inner_tape.batch_jacobian(reduced, x))

  # Numeric-vs-analytic check of the traced function's gradient.
  jac_theoretical, jac_numerical = gradient_checker_v2.compute_gradient(
      def_function.function(_inner), [array_ops.ones([10, 4, 4, 1])])
  self.assertAllClose(jac_numerical, jac_theoretical, rtol=1e-1)

  @def_function.function
  def _outer():
    with backprop.GradientTape() as outer_tape:
      x = array_ops.ones([10, 4, 4, 1])
      outer_tape.watch(x)
      y = _inner(x)
    return outer_tape.gradient(y, x)

  # Second-order check: outer tape gradient must agree with the numeric
  # jacobian computed above.
  self.assertAllClose(array_ops.reshape(jac_numerical, [-1]),
                      array_ops.reshape(_outer(), [-1]), rtol=1e-1)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_indexed_slices(self):
|
||||
with backprop.GradientTape(persistent=True) as g:
|
||||
|
def _create_op(op_type, inputs, op_dtypes, attrs=None):
  """Utility to create an op and record its gradient on the tape.

  Args:
    op_type: String name of the op to create.
    inputs: Flat list of input tensors for the op.
    op_dtypes: Output dtypes of the op.
    attrs: Optional iterable of attribute names to read back from the
      created op and pass to the tape. May be None (no attributes).

  Returns:
    The created `Operation`.
  """
  op = ops.get_default_graph().create_op(
      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
  # The tape expects an alternating flat list of names and attribute values.
  # Build it with an explicit loop: nest.flatten would also flatten
  # list-valued attribute values (e.g. strides), mangling the name/value
  # pairing (see #40895).
  flat_attrs = []
  # `attrs or ()` guards the documented attrs=None default, which would
  # otherwise raise TypeError here.
  for a in attrs or ():
    flat_attrs.append(str(a))
    flat_attrs.append(op.get_attr(str(a)))
  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
  return op
Loading…
Reference in New Issue
Block a user