gradient_checker: Convert arguments to tensors.
Without this change, the test added to gradient_checker_v2_test.py would fail with an error like: AttributeError: 'numpy.ndarray' object has no attribute '_id' in the call to "tape.watch(x)". Since gradient_checker_v2.compute_gradients() requires that the "x" argument be a list of Tensors, it seems reasonable to ensure that the "f" argument is provided with Tensors. While at it, also check that the "x" argument is a list to provide a clearer error message than would be obtained without this check. PiperOrigin-RevId: 223820537
This commit is contained in:
parent
2dcceda2d2
commit
556e5ca311
@ -88,14 +88,18 @@ def _prepare(f, xs_dtypes):
|
||||
a function that will be evaluated in both graph and eager mode
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
return f
|
||||
|
||||
def decorated_eager(*xs_data):
|
||||
return f(*map(ops.convert_to_tensor, xs_data))
|
||||
|
||||
return decorated_eager
|
||||
xs = [array_ops.placeholder(x_dtype) for x_dtype in xs_dtypes]
|
||||
y = f(*xs)
|
||||
sess = ops.get_default_session()
|
||||
def decorated(*xs_data):
|
||||
def decorated_graph(*xs_data):
|
||||
xs_data = [_to_numpy(a) for a in xs_data]
|
||||
return sess.run(y, feed_dict=dict(zip(xs, xs_data)))
|
||||
return decorated
|
||||
return decorated_graph
|
||||
|
||||
|
||||
def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param):
|
||||
@ -288,6 +292,9 @@ def compute_gradient(f, x, delta=1e-3):
|
||||
Raises:
|
||||
ValueError: If result is empty but the gradient is nonzero.
|
||||
"""
|
||||
if not isinstance(x, list):
|
||||
raise ValueError(
|
||||
"`x` must be a list of Tensors (arguments to `f`), not a %s" % type(x))
|
||||
return _compute_gradient_list(f, x, delta)
|
||||
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -177,6 +178,19 @@ class GradientCheckerTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(AssertionError, "False is not true"):
|
||||
self.assertTrue(error < 1.0)
|
||||
|
||||
def testGradGrad(self):
|
||||
|
||||
def f(x):
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = math_ops.square(x)
|
||||
z = math_ops.square(y)
|
||||
return tape.gradient(z, x)
|
||||
|
||||
analytical, numerical = gradient_checker.compute_gradient(f, [2.0])
|
||||
self.assertAllEqual([[[48.]]], analytical)
|
||||
self.assertAllClose([[[48.]]], numerical, rtol=1e-4)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class MiniMNISTTest(test.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user