gradient_checker: Convert arguments to tensors.

Without this change, the test added to gradient_checker_v2_test.py
would fail with an error like:

AttributeError: 'numpy.ndarray' object has no attribute '_id'

in the call to "tape.watch(x)".
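
For context: GradientTape.watch() accepts only Tensors, so the checker's
arguments must be converted before f runs. A minimal sketch of the failure
and the fix, using the public tf.* aliases of the internal modules touched
below (assuming eager execution):

    import numpy as np
    import tensorflow as tf

    x = np.array(2.0, dtype=np.float32)
    with tf.GradientTape() as tape:
      # tape.watch(x) would fail here: a numpy array lacks the Tensor
      # attributes (such as _id) that watch() relies on.
      x = tf.convert_to_tensor(x)  # what _prepare now does for each argument
      tape.watch(x)
      y = tf.square(x)
    print(tape.gradient(y, x))  # tf.Tensor(4.0, ...)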

Since gradient_checker_v2.compute_gradient() requires that the "x" argument be
a list of Tensors, it seems reasonable to ensure that the "f" argument is
provided with Tensors.

While at it, also check that the "x" argument is a list to provide a clearer
error message than would be obtained without this check.
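
To illustrate the intended call pattern (a hedged sketch; math_ops.square
stands in for an arbitrary differentiable function, and the error text is
taken from the diff below):

    from tensorflow.python.ops import gradient_checker_v2
    from tensorflow.python.ops import math_ops

    # OK: `x` is a list; each element is converted to a Tensor before f sees it.
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        math_ops.square, [2.0])

    # Fails fast with the new, clearer error instead of a confusing one later:
    # ValueError: `x` must be a list of Tensors (arguments to `f`), not a
    # <class 'float'>
    gradient_checker_v2.compute_gradient(math_ops.square, 2.0)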

PiperOrigin-RevId: 223820537
Author: Asim Shankar (2018-12-03 10:38:59 -08:00), committed by TensorFlower Gardener
Parent: 2dcceda2d2
Commit: 556e5ca311
2 changed files with 24 additions and 3 deletions

gradient_checker_v2.py

@@ -88,14 +88,18 @@ def _prepare(f, xs_dtypes):
     a function that will be evaluated in both graph and eager mode
   """
   if context.executing_eagerly():
-    return f
+
+    def decorated_eager(*xs_data):
+      return f(*map(ops.convert_to_tensor, xs_data))
+
+    return decorated_eager
   xs = [array_ops.placeholder(x_dtype) for x_dtype in xs_dtypes]
   y = f(*xs)
   sess = ops.get_default_session()
 
-  def decorated(*xs_data):
+  def decorated_graph(*xs_data):
     xs_data = [_to_numpy(a) for a in xs_data]
     return sess.run(y, feed_dict=dict(zip(xs, xs_data)))
 
-  return decorated
+  return decorated_graph
 
 
 def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param):
@@ -288,6 +292,9 @@ def compute_gradient(f, x, delta=1e-3):
   Raises:
     ValueError: If result is empty but the gradient is nonzero.
   """
+  if not isinstance(x, list):
+    raise ValueError(
+        "`x` must be a list of Tensors (arguments to `f`), not a %s" % type(x))
   return _compute_gradient_list(f, x, delta)
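
The decorated_eager wrapper introduced above is plain argument coercion
around f. A standalone sketch of the same idea (the _wrap_eager name is
hypothetical, not part of the change):

    import tensorflow as tf

    def _wrap_eager(f):
      def decorated_eager(*xs_data):
        # Convert every positional argument to a Tensor before calling f,
        # so f can safely use Tensor-only APIs such as tape.watch().
        return f(*map(tf.convert_to_tensor, xs_data))
      return decorated_eager

    g = _wrap_eager(lambda x: x * x)
    print(g(3.0))  # the float reaches the lambda as a tf.Tensor; prints 9.0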

gradient_checker_v2_test.py

@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy as np
 
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -177,6 +178,19 @@ class GradientCheckerTest(test.TestCase):
     with self.assertRaisesRegexp(AssertionError, "False is not true"):
       self.assertTrue(error < 1.0)
 
+  def testGradGrad(self):
+
+    def f(x):
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        y = math_ops.square(x)
+        z = math_ops.square(y)
+      return tape.gradient(z, x)
+
+    analytical, numerical = gradient_checker.compute_gradient(f, [2.0])
+    self.assertAllEqual([[[48.]]], analytical)
+    self.assertAllClose([[[48.]]], numerical, rtol=1e-4)
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class MiniMNISTTest(test.TestCase):
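
As a sanity check on the expected value in testGradGrad above: f returns the
gradient of z = (x**2)**2 = x**4, i.e. f(x) = 4*x**3, and compute_gradient
then differentiates f, giving 12*x**2 = 48 at x = 2. In plain Python:

    f = lambda x: 4 * x ** 3    # gradient of x**4, what the test's f computes
    df = lambda x: 12 * x ** 2  # what compute_gradient checks against
    assert df(2.0) == 48.0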