Forwardprop: fix variables inside functions, test with layers
We were trying to use the EagerTensor variable handle inside the function. This change stops doing that, and also makes it not segfault if it were to do that. It also includes a couple of miscellaneous fixes to the test infrastructure to support the new tests, e.g. the gradient checker didn't like unconnected gradients.

PiperOrigin-RevId: 266958407
This commit is contained in:
parent 619ad608cb
commit 7403920848
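In user terms, the bug was that watching a tf.Variable with a ForwardGradientAccumulator inside a tf.function fed the accumulator the variable's outer EagerTensor handle, which could segfault. A minimal sketch of the now-supported pattern, condensed from the testVariableWatchedFunction test added below (internal imports as in the test file; the expected values are the ones the test asserts):

    from tensorflow.python.eager import def_function
    from tensorflow.python.eager import forwardprop
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import variables

    v = variables.Variable([1., 2., 3.])

    @def_function.function
    def compute_jvps():
      # Watching a variable while tracing now resolves to the handle captured
      # in the function graph instead of the outer EagerTensor handle.
      with forwardprop.ForwardGradientAccumulator() as acc:
        acc.watch(v, constant_op.constant([.1, -.2, .3]))
        x = v * 2.
      return acc.jvp(x)

    print(compute_jvps())  # Tangent pushed through x = 2 * v: roughly [.2, -.4, .6].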
@@ -922,6 +922,11 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
   for (const TapeTensor& output_tensor : output_tensors) {
     // Ownership of `aid` transferred to CallBackwardFunction below.
     Gradient* aid = vspace_.Ones(output_tensor);
+    if (TF_PREDICT_FALSE(aid == nullptr)) {
+      return tensorflow::errors::Internal(
+          "Failed to create ones tensor for tensor ", output_tensor.GetID(),
+          " with dtype ", output_tensor.GetDType());
+    }
     forwardprop_aids.push_back(aid);
     int64 aid_id = vspace_.TensorId(aid);
     sources.push_back(aid_id);
@@ -187,7 +187,9 @@ class ForwardGradientAccumulator(object):
             "floating (e.g. tf.float32), got %r", 5, t.dtype)
       g = ops.convert_to_tensor(g, dtype=t.dtype)
       if hasattr(t, "handle"):
-        t = t.handle
+        # Run convert_to_tensor to get the captured handle from whichever
+        # function we're running if necessary.
+        t = ops.convert_to_tensor(t.handle)
       pywrap_tensorflow.TFE_Py_ForwardAccumulatorWatch(self._accumulator, t, g)
 
   def jvp(self, target):
@@ -209,7 +211,7 @@ class ForwardGradientAccumulator(object):
       raise ValueError("Called jvp() without first tracing anything.")
     def _fetch_jvp(tensor):
       if hasattr(tensor, "handle"):
-        tensor = tensor.handle
+        tensor = ops.convert_to_tensor(tensor.handle)
       return pywrap_tensorflow.TFE_Py_ForwardAccumulatorJVP(
           self._accumulator, tensor)
     return nest.map_structure(_fetch_jvp, target)
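For reference, a rough standalone illustration (not the library code itself; the helper name _resolve_watched is made up) of what the hasattr branch above now does with a variable argument: variables are resolved through their resource handle with convert_to_tensor, so that under a tf.function the accumulator sees the handle captured into the function graph rather than the outer EagerTensor.

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables

    def _resolve_watched(t):
      # Mirrors the change above: variables expose `handle`; converting it
      # yields the tensor that traced ops actually consume (captured into the
      # current graph if one is being built).
      if hasattr(t, "handle"):
        t = ops.convert_to_tensor(t.handle)
      return t

    v = variables.Variable([1., 2., 3.])
    print(_resolve_watched(v).dtype)  # A resource-dtype tensor, not the variable object.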
@@ -31,8 +31,13 @@ from tensorflow.python.eager import forwardprop
 from tensorflow.python.eager import forwardprop_util
 from tensorflow.python.eager import tape as tape_lib
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.layers import convolutional
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import normalization_v2
+from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradient_checker_v2
@@ -59,7 +64,12 @@ def _jvp(f, primals, tangents):
   with forwardprop.ForwardGradientAccumulator() as acc:
     acc.watch(primals, tangents)
     primals_out = f(*primals)
-  return primals_out, acc.jvp(primals_out)
+  tangents_out = acc.jvp(primals_out)
+  if primals_out is not None and tangents_out is None:
+    # TODO(allenl): Support UnconnectedGradients as an accumulator constructor
+    # argument.
+    return primals_out, array_ops.zeros_like(primals_out)
+  return primals_out, tangents_out
 
 
 def _jacfwd(f, primals):
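The new None check covers outputs that are unconnected to the watched primals, which the gradient checker exercises once variables are involved. A toy illustration of the case, assuming the private _jvp helper above is in scope (hypothetical usage, not a test from this change):

    from tensorflow.python.framework import constant_op

    primal = constant_op.constant(2.)
    tangent = constant_op.constant(1.)
    # The output ignores the watched primal, so acc.jvp() yields None and the
    # helper substitutes zeros shaped like the output.
    out, jvp = _jvp(lambda x: constant_op.constant(3.), [primal], [tangent])
    # out == 3.0; jvp == 0.0 rather than None.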
@@ -288,6 +298,66 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
   def testElementwiseNNOps(self, value, op_fn):
     _test_gradients(self, op_fn, [constant_op.constant(value)], order=3)
 
+  @parameterized.named_parameters(
+      [("Dense", [[0.1]], functools.partial(core.Dense, 5)),
+       ("Conv2D",
+        np.reshape(np.arange(start=-1., stop=1., step=2. / (1 * 2 * 4 * 4)),
+                   [1, 2, 4, 4]),
+        functools.partial(convolutional.Conv2D, 2, 2),
+        1e-4),
+       ("BatchNorm", [[0.1], [0.2], [-0.3]],
+        normalization_v2.BatchNormalization)])
+  def testKerasLayers(self, value, op_fn, atol=1e-6):
+    layer = op_fn()
+    input_value = constant_op.constant(value, dtype=dtypes.float32)
+    _test_gradients(
+        self, layer, [input_value], atol=atol,
+        # These are linear, so second-order is pretty boring.
+        order=2)
+
+  @parameterized.named_parameters(
+      [("Function", def_function.function),
+       ("NoFunction", lambda f: f)])
+  def testVariablesHVP(self, decorator):
+
+    class _Model(module.Module):
+
+      def __init__(self):
+        self._first_dense = core.Dense(18)
+        self._conv = convolutional.Conv2D(2, 2)
+        self._norm = normalization_v2.BatchNormalization()
+        self._second_dense = core.Dense(1)
+
+      def __call__(self, x):
+        x = self._first_dense(x)
+        x = nn_ops.relu(x)
+        x = self._norm(x)
+        x = nn_ops.relu(self._conv(array_ops.reshape(x, [-1, 2, 3, 3])))
+        return self._second_dense(x)
+
+    model = _Model()
+    def _loss():
+      input_value = constant_op.constant([[-0.5, 1.], [0.5, -1.]])
+      target = constant_op.constant([[-1.], [2.]])
+      return math_ops.reduce_sum((model(input_value) - target) ** 2.)
+
+    @decorator
+    def _compute_hvps():
+      with backprop.GradientTape() as tape:
+        loss = _loss()
+      vector = tape.gradient(loss, model.trainable_variables)
+      variable_input_fn = lambda unused_variables: _loss()
+      forward_over_back_hvp = _hvp(
+          variable_input_fn, [model.trainable_variables], [vector])
+      with backprop.GradientTape(persistent=True) as tape:
+        tape.watch(model.trainable_variables)
+        loss = _loss()
+        first_grads = tape.gradient(loss, model.trainable_variables)
+      back_over_back_hvp = tape.gradient(
+          first_grads, model.trainable_variables, output_gradients=vector)
+      return forward_over_back_hvp, back_over_back_hvp
+    self.assertAllClose(*_compute_hvps(), rtol=1e-5, atol=1e-5)
+
   def testPushPopAccumulatorState(self):
     # Note that this example is somewhat contrived. push_forwardprop_state is
     # probably only useful in practice for building functions that compute jvps
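testVariablesHVP checks a forward-over-back Hessian-vector product against a back-over-back one. The _hvp helper it calls is defined earlier in the test file and is not shown in this hunk; a sketch in the same style and with the same imports (an illustration under that assumption, not necessarily the file's exact helper):

    def _hvp(f, primals, tangents):
      """Forward-over-back Hessian-vector product sketch."""
      with forwardprop.ForwardGradientAccumulator() as acc:
        acc.watch(primals, tangents)
        with backprop.GradientTape() as tape:
          tape.watch(primals)
          primals_out = f(*primals)
        # Still inside the accumulator: the backward pass itself is recorded,
        # so its result carries the forward-mode tangents of the gradient,
        # i.e. the Hessian applied to `tangents`.
        grads = tape.gradient(primals_out, primals)
      return acc.jvp(grads)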
@@ -571,6 +641,30 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
     x2 = v + .1
     self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testVariableWatchedFunction(self):
+
+    class _Model(module.Module):
+
+      def __init__(self):
+        self._v = None
+
+      @def_function.function
+      def compute_jvps(self):
+        if self._v is None:
+          self._v = variables.Variable([1., 2., 3.])
+        with forwardprop.ForwardGradientAccumulator() as acc:
+          acc.watch(self._v, constant_op.constant([.1, -.2, .3]))
+          x = self._v * 2.
+          x2 = self._v + .1
+        return acc.jvp((self._v, x, x2))
+
+    model = _Model()
+    v_jvp, x_jvp, x2_jvp = model.compute_jvps()
+    self.assertAllClose([.1, -.2, .3], v_jvp)
+    self.assertAllClose([.2, -.4, .6], x_jvp)
+    self.assertAllClose([.1, -.2, .3], x2_jvp)
+
   # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
   # test... could be something wrong with the test decorator, or some sort of
   # nondeterministic caching.
@@ -175,7 +175,7 @@ def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param):
         r_begin = i * x_val_size
         r_end = r_begin + x_val_size
         jacobian[r_begin:r_end, col] += v.flat
-    else:
+    elif grad is not None:
      jacobian[:, col] = grad.ravel().view(jacobian.dtype)
 
   # If the output is empty, run the gradients at least once and make sure
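Changing `else:` to `elif grad is not None:` lets the theoretical-Jacobian loop tolerate unconnected gradients: when an input does not influence the output, tape-based differentiation returns None and the corresponding block of the Jacobian is simply left at zero. A hedged usage sketch, assuming compute_gradient otherwise supports such inputs:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import gradient_checker_v2

    def f(x, unused_y):
      # `unused_y` never affects the output, so its backpropagated gradient is
      # None; the checker now reports an all-zero theoretical Jacobian for it.
      return x * 2.

    theoretical, numerical = gradient_checker_v2.compute_gradient(
        f, [constant_op.constant(3.), constant_op.constant(5.)])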