Fix tape recording of attributes of vectorized ops
nest.flatten was mangling some attribute lists Fixes #40895. PiperOrigin-RevId: 339965951 Change-Id: Ibdbbb5a12b2d06a077b62902764ca4849a2b8d2a
This commit is contained in:
parent
b37b0d00f1
commit
3d1f1b062d
@ -1772,6 +1772,31 @@ class JacobianTest(test.TestCase):
|
|||||||
array_ops.reshape(def_function.function(f)(x), [-1]),
|
array_ops.reshape(def_function.function(f)(x), [-1]),
|
||||||
rtol=1e-3)
|
rtol=1e-3)
|
||||||
|
|
||||||
|
def test_grad_jacobian_conv(self):
|
||||||
|
def _inner(x):
|
||||||
|
kernel = array_ops.ones([3, 3, 1, 9])
|
||||||
|
with backprop.GradientTape() as tape:
|
||||||
|
tape.watch(x)
|
||||||
|
y = nn_ops.conv2d(x, kernel, strides=(1, 1), padding='SAME',
|
||||||
|
data_format='NHWC')
|
||||||
|
reduced = math_ops.reduce_sum(y ** 2., axis=[2, 3])
|
||||||
|
return math_ops.reduce_sum(tape.batch_jacobian(reduced, x))
|
||||||
|
|
||||||
|
theoretical, numerical = gradient_checker_v2.compute_gradient(
|
||||||
|
def_function.function(_inner), [array_ops.ones([10, 4, 4, 1])])
|
||||||
|
self.assertAllClose(numerical, theoretical, rtol=1e-1)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def _outer():
|
||||||
|
with backprop.GradientTape() as tape:
|
||||||
|
x = array_ops.ones([10, 4, 4, 1])
|
||||||
|
tape.watch(x)
|
||||||
|
y = _inner(x)
|
||||||
|
return tape.gradient(y, x)
|
||||||
|
|
||||||
|
self.assertAllClose(array_ops.reshape(numerical, [-1]),
|
||||||
|
array_ops.reshape(_outer(), [-1]), rtol=1e-1)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_indexed_slices(self):
|
def test_indexed_slices(self):
|
||||||
with backprop.GradientTape(persistent=True) as g:
|
with backprop.GradientTape(persistent=True) as g:
|
||||||
|
@ -1001,7 +1001,11 @@ def _create_op(op_type, inputs, op_dtypes, attrs=None):
|
|||||||
"""Utility to create an op."""
|
"""Utility to create an op."""
|
||||||
op = ops.get_default_graph().create_op(
|
op = ops.get_default_graph().create_op(
|
||||||
op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
|
op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
|
||||||
flat_attrs = nest.flatten([(str(a), op.get_attr(str(a))) for a in attrs])
|
flat_attrs = []
|
||||||
|
# The tape expects an alternating flat list of names and attribute values.
|
||||||
|
for a in attrs:
|
||||||
|
flat_attrs.append(str(a))
|
||||||
|
flat_attrs.append(op.get_attr(str(a)))
|
||||||
execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
|
execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
|
||||||
return op
|
return op
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user