Fix the gradient of x**0 wrt the base (it's now zero instead of NaN) for real dtypes when x is 0 or NaN.
We already have similar special casing for the gradient wrt the power when the base is zero Fixes #35011. PiperOrigin-RevId: 309789626 Change-Id: I446279fd9215c8abe1ec329f7b9b25f05772e276
This commit is contained in:
parent
9b2dc5ff87
commit
17ba8a6ed4
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import backprop
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -29,7 +28,6 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gradient_checker
|
from tensorflow.python.ops import gradient_checker
|
||||||
from tensorflow.python.ops import gradient_checker_v2
|
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
||||||
@ -761,7 +759,7 @@ class BinaryOpTest(test.TestCase):
|
|||||||
ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
|
ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testZeroBasePowGrad(self):
|
def testZeroPowGrad(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
for dtype in (np.float16, np.float32, np.float64, np.complex64,
|
for dtype in (np.float16, np.float32, np.float64, np.complex64,
|
||||||
np.complex128):
|
np.complex128):
|
||||||
@ -771,43 +769,6 @@ class BinaryOpTest(test.TestCase):
|
|||||||
error = gradient_checker.compute_gradient_error(y, [], z, [])
|
error = gradient_checker.compute_gradient_error(y, [], z, [])
|
||||||
self.assertEqual(error, 0)
|
self.assertEqual(error, 0)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testZeroPowerPowGrad(self):
|
|
||||||
# Tests for 0. ** 0.'s gradient with respect to the base for real dtypes
|
|
||||||
# only. For complex types 0. ** 0. itself isn't well defined, so we'd get a
|
|
||||||
# non-deterministic smattering of NaNs in the test.
|
|
||||||
|
|
||||||
# pylint: disable=cell-var-from-loop
|
|
||||||
for dtype in (np.float16, np.float32, np.float64):
|
|
||||||
sym_jac, num_jac = gradient_checker_v2.compute_gradient(
|
|
||||||
lambda x: math_ops.pow(x, 0.),
|
|
||||||
[constant_op.constant([-1., 0., 1.], dtype=dtype)])
|
|
||||||
self.assertAllClose(sym_jac, num_jac)
|
|
||||||
power = constant_op.constant([0., 0., 0.], dtype=dtype)
|
|
||||||
sym_jac, num_jac = gradient_checker_v2.compute_gradient(
|
|
||||||
lambda x: math_ops.pow(x, power),
|
|
||||||
[constant_op.constant([-1., 0., 1.], dtype=dtype)])
|
|
||||||
self.assertAllClose(sym_jac, num_jac)
|
|
||||||
with backprop.GradientTape() as tape:
|
|
||||||
x = constant_op.constant(float("NaN"), dtype=dtype)
|
|
||||||
tape.watch(x)
|
|
||||||
y = x ** 0.
|
|
||||||
self.assertAllClose(0., tape.gradient(y, x))
|
|
||||||
# pylint: enable=cell-var-from-loop
|
|
||||||
|
|
||||||
x = constant_op.constant(0.)
|
|
||||||
with backprop.GradientTape(persistent=True) as tape:
|
|
||||||
tape.watch(x)
|
|
||||||
y = math_ops.pow(x, 2.)
|
|
||||||
x_g = tape.gradient(y, x)
|
|
||||||
self.assertAllClose(2. * 0. ** 1., x_g)
|
|
||||||
x_gg = tape.gradient(x_g, x)
|
|
||||||
self.assertAllClose(2. * 0. ** 0., x_gg)
|
|
||||||
x_ggg = tape.gradient(x_gg, x)
|
|
||||||
self.assertAllClose(0., x_ggg)
|
|
||||||
# Note that higher-order gradients currently return NaN since backprop
|
|
||||||
# isn't very smart.
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testComplexPowGrad(self):
|
def testComplexPowGrad(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
|
@ -1338,28 +1338,7 @@ def _PowGrad(op, grad):
|
|||||||
y):
|
y):
|
||||||
x = math_ops.conj(x)
|
x = math_ops.conj(x)
|
||||||
y = math_ops.conj(y)
|
y = math_ops.conj(y)
|
||||||
if x.dtype.is_complex:
|
return grad * y * math_ops.pow(x, y - 1), None
|
||||||
dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
|
|
||||||
else:
|
|
||||||
# Define x ** 0 to have a derivative of zero with respect to x for real
|
|
||||||
# dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0. It is a
|
|
||||||
# more efficient equivalent of this subgraph:
|
|
||||||
#
|
|
||||||
# raw_derivative = y * tf.pow(x, y - 1)
|
|
||||||
# dx = tf.where(tf.broadcast_to(tf.not_equal(y, 0),
|
|
||||||
# tf.shape(raw_derivative)),
|
|
||||||
# raw_derivative,
|
|
||||||
# tf.zeros_like(raw_derivative))
|
|
||||||
#
|
|
||||||
# It does not suppress all NaNs. mul_no_nan only differs from regular
|
|
||||||
# multiplication when `y = 0`, in which case `tf.pow(x, y - 1)` would
|
|
||||||
# ordinarily be NaN if `x = 0` (leading to a NaN gradient). Small
|
|
||||||
# perturbations to `x` would not affect the result of `x ** 0`, meaning
|
|
||||||
# the correct gradient is 0 whenever `y = 0`. This special-casing also
|
|
||||||
# defines a gradient for `x` when `x = NaN` and `y = 0`; this follows
|
|
||||||
# from the convention that `NaN ** 0 = 1`.
|
|
||||||
dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
|
|
||||||
return grad * dx, None
|
|
||||||
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# No gradient skipping, so do the full gradient computation
|
# No gradient skipping, so do the full gradient computation
|
||||||
@ -1371,13 +1350,7 @@ def _PowGrad(op, grad):
|
|||||||
y = math_ops.conj(y)
|
y = math_ops.conj(y)
|
||||||
|
|
||||||
if skip_input_indices is None or 0 not in skip_input_indices:
|
if skip_input_indices is None or 0 not in skip_input_indices:
|
||||||
if x.dtype.is_complex:
|
gx = grad * y * math_ops.pow(x, y - 1)
|
||||||
dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
|
|
||||||
else:
|
|
||||||
# Define x ** 0 to have a derivative of zero with respect to x for real
|
|
||||||
# dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0.
|
|
||||||
dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
|
|
||||||
gx = grad * dx
|
|
||||||
if must_reduce_x:
|
if must_reduce_x:
|
||||||
gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
|
gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user