Fix the gradient of x**0 wrt the base (it's now zero instead of NaN) for real dtypes when x is 0 or NaN.

We already have similar special casing for the gradient wrt the power when the base is zero.

Fixes #35011.

PiperOrigin-RevId: 309789626
Change-Id: I446279fd9215c8abe1ec329f7b9b25f05772e276
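For context, a minimal sketch (not part of the change itself) of the behavior this fixes, using the public tf.GradientTape API; with the special casing in place, the gradient of x ** 0. with respect to x comes back as 0 rather than NaN when x is 0 or NaN:

    import tensorflow as tf

    for value in (0., float("nan")):
      x = tf.constant(value)
      with tf.GradientTape() as tape:
        tape.watch(x)
        y = x ** 0.  # 0 ** 0 and NaN ** 0 both evaluate to 1 by convention
      # Perturbing x cannot change x ** 0, so the correct gradient is 0.
      # Before this change the backpropagated value was NaN for these inputs.
      print(tape.gradient(y, x).numpy())  # expected: 0.0 for both values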
commit 17ba8a6ed4 (parent 9b2dc5ff87)
@@ -20,7 +20,6 @@ from __future__ import print_function

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors

@@ -29,7 +28,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
@@ -761,7 +759,7 @@ class BinaryOpTest(test.TestCase):
          ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))

  @test_util.run_deprecated_v1
  def testZeroBasePowGrad(self):
  def testZeroPowGrad(self):
    with self.cached_session():
      for dtype in (np.float16, np.float32, np.float64, np.complex64,
                    np.complex128):

@@ -771,43 +769,6 @@ class BinaryOpTest(test.TestCase):
        error = gradient_checker.compute_gradient_error(y, [], z, [])
        self.assertEqual(error, 0)

  @test_util.run_in_graph_and_eager_modes
  def testZeroPowerPowGrad(self):
    # Tests for 0. ** 0.'s gradient with respect to the base for real dtypes
    # only. For complex types 0. ** 0. itself isn't well defined, so we'd get a
    # non-deterministic smattering of NaNs in the test.

    # pylint: disable=cell-var-from-loop
    for dtype in (np.float16, np.float32, np.float64):
      sym_jac, num_jac = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.pow(x, 0.),
          [constant_op.constant([-1., 0., 1.], dtype=dtype)])
      self.assertAllClose(sym_jac, num_jac)
      power = constant_op.constant([0., 0., 0.], dtype=dtype)
      sym_jac, num_jac = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.pow(x, power),
          [constant_op.constant([-1., 0., 1.], dtype=dtype)])
      self.assertAllClose(sym_jac, num_jac)
      with backprop.GradientTape() as tape:
        x = constant_op.constant(float("NaN"), dtype=dtype)
        tape.watch(x)
        y = x ** 0.
      self.assertAllClose(0., tape.gradient(y, x))
    # pylint: enable=cell-var-from-loop

    x = constant_op.constant(0.)
    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(x)
      y = math_ops.pow(x, 2.)
      x_g = tape.gradient(y, x)
      self.assertAllClose(2. * 0. ** 1., x_g)
      x_gg = tape.gradient(x_g, x)
      self.assertAllClose(2. * 0. ** 0., x_gg)
      x_ggg = tape.gradient(x_gg, x)
      self.assertAllClose(0., x_ggg)
      # Note that higher-order gradients currently return NaN since backprop
      # isn't very smart.

  @test_util.run_deprecated_v1
  def testComplexPowGrad(self):
    with self.cached_session():
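The x_g / x_gg / x_ggg assertions in the hunk above walk the derivative chain of x ** 2 at x = 0: the first gradient is 2 * x ** 1, the second is 2 * 1 * x ** 0, and the third multiplies by the now-zero exponent, so the naive formula would produce 0 * x ** -1 (NaN at x = 0) while the special-cased gradient returns 0. A hedged sketch of the same check through the public API (illustrative only, assuming this change is applied):

    import tensorflow as tf

    x = tf.constant(0.)
    with tf.GradientTape(persistent=True) as tape:
      tape.watch(x)
      y = x ** 2.
      g1 = tape.gradient(y, x)   # 2 * 0 ** 1       -> 0
      g2 = tape.gradient(g1, x)  # 2 * 1 * 0 ** 0   -> 2
      g3 = tape.gradient(g2, x)  # zero exponent factor suppresses the NaN -> 0
    print(g1.numpy(), g2.numpy(), g3.numpy())  # expected: 0.0 2.0 0.0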
@@ -1338,28 +1338,7 @@ def _PowGrad(op, grad):
        y):
      x = math_ops.conj(x)
      y = math_ops.conj(y)
      if x.dtype.is_complex:
        dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
      else:
        # Define x ** 0 to have a derivative of zero with respect to x for real
        # dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0. It is a
        # more efficient equivalent of this subgraph:
        #
        #   raw_derivative = y * tf.pow(x, y - 1)
        #   dx = tf.where(tf.broadcast_to(tf.not_equal(y, 0),
        #                                 tf.shape(raw_derivative)),
        #                 raw_derivative,
        #                 tf.zeros_like(raw_derivative))
        #
        # It does not suppress all NaNs. mul_no_nan only differs from regular
        # multiplication when `y = 0`, in which case `tf.pow(x, y - 1)` would
        # ordinarily be NaN if `x = 0` (leading to a NaN gradient). Small
        # perturbations to `x` would not affect the result of `x ** 0`, meaning
        # the correct gradient is 0 whenever `y = 0`. This special-casing also
        # defines a gradient for `x` when `x = NaN` and `y = 0`; this follows
        # from the convention that `NaN ** 0 = 1`.
        dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
      return grad * dx, None
      return grad * y * math_ops.pow(x, y - 1), None

  except AttributeError:
    # No gradient skipping, so do the full gradient computation

@@ -1371,13 +1350,7 @@ def _PowGrad(op, grad):
  y = math_ops.conj(y)

  if skip_input_indices is None or 0 not in skip_input_indices:
    if x.dtype.is_complex:
      dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
    else:
      # Define x ** 0 to have a derivative of zero with respect to x for real
      # dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0.
      dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
    gx = grad * dx
    gx = grad * y * math_ops.pow(x, y - 1)
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
  else:
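The comment block above describes gen_math_ops.mul_no_nan(pow(x, y - 1), y) as a cheaper, fused equivalent of a tf.where-guarded product. A hedged, standalone sketch of that documented subgraph using public ops (the helper name pow_grad_wrt_base is illustrative, not TensorFlow API):

    import tensorflow as tf

    def pow_grad_wrt_base(x, y):
      # Naive derivative of x ** y with respect to x: y * x ** (y - 1).
      # Where y == 0 this evaluates pow(x, -1), which is inf/NaN for x == 0.
      raw_derivative = y * tf.pow(x, y - 1)
      # Zero out exactly the y == 0 entries; perturbing x cannot change
      # x ** 0, so the correct gradient there is 0. This is what
      # mul_no_nan(pow(x, y - 1), y) computes in a single fused op.
      return tf.where(
          tf.broadcast_to(tf.not_equal(y, 0), tf.shape(raw_derivative)),
          raw_derivative,
          tf.zeros_like(raw_derivative))

    x = tf.constant([0., 3., 2.])
    y = tf.constant([0., 2., 0.])
    print(pow_grad_wrt_base(x, y).numpy())  # expected: [0. 6. 0.]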