makes gradient_checker_v2 work with sparse tensor reshape.
PiperOrigin-RevId: 301667976 Change-Id: I133c918c8a5703b40d4b5a69ead4bc0ec74f27b8
This commit is contained in:
parent
5041af7879
commit
5b00c31c4b
@ -174,7 +174,6 @@ def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param):
|
||||
dy_data_flat[row] = 1
|
||||
grad = _to_numpy(grad_fn(dy_data, *xs)[0])
|
||||
grad = _eval_indexed_slices(grad)
|
||||
dy_data_flat[row] = 0
|
||||
if isinstance(grad, ops.IndexedSlicesValue):
|
||||
for i, v in zip(grad.indices, grad.values):
|
||||
c_begin = i * x_val_size
|
||||
@ -182,6 +181,9 @@ def _compute_theoretical_jacobian(f, y_shape, y_dtype, xs, param):
|
||||
jacobian[row, c_begin:c_end] += v.flat
|
||||
elif grad is not None:
|
||||
jacobian[row, :] = grad.ravel().view(jacobian.dtype)
|
||||
# This reset of `dy_data_flat` needs to happen after `grad` is copied to
|
||||
# `jacobian` because `grad` and `dy_data_flat` may share memory.
|
||||
dy_data_flat[row] = 0
|
||||
|
||||
# If the output is empty, run the gradients at least once and make sure
|
||||
# they produce zeros.
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
@ -30,6 +31,7 @@ from tensorflow.python.ops import \
|
||||
gradient_checker_v2 as gradient_checker
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
# needs this to register gradient for SoftmaxCrossEntropyWithLogits:
|
||||
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
@ -46,6 +48,20 @@ def _random_complex(shape, dtype):
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class GradientCheckerTest(test.TestCase):
|
||||
|
||||
def testSparseTensorReshape(self):
|
||||
x = constant_op.constant(2.0, shape=(2,))
|
||||
|
||||
def sparse_tensor_reshape(values):
|
||||
sparse = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=values, dense_shape=[3, 4])
|
||||
sparse = sparse_ops.sparse_reshape(sparse, shape=(12,))
|
||||
return sparse.values
|
||||
|
||||
error = gradient_checker.max_error(
|
||||
*gradient_checker.compute_gradient(sparse_tensor_reshape, [x]))
|
||||
|
||||
self.assertLess(error, 1e-4)
|
||||
|
||||
def testWithStaticShape(self):
|
||||
size = (2, 3)
|
||||
constant = constant_op.constant(2.0, shape=size, name="const")
|
||||
|
Loading…
x
Reference in New Issue
Block a user