Fix the gradient of x**0 wrt the base (it's now zero instead of NaN) for real dtypes when x is 0 or NaN.

We already have similar special-casing for the gradient wrt the power when the base is zero.

Fixes #35011.
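
For illustration only (not part of this change): a minimal sketch of the behavior being fixed, assuming an eager TF 2.x setup; the exact output depends on whether this change is present in the installed build.

    import tensorflow as tf

    x = tf.constant(0.0)
    with tf.GradientTape() as tape:
      tape.watch(x)        # x is a constant, so it must be watched explicitly
      y = tf.pow(x, 0.0)   # 0. ** 0. evaluates to 1.
    # Without the special case the gradient is 0. * 0. ** -1 = NaN; with the
    # mul_no_nan-based gradient it is 0.
    print(tape.gradient(y, x))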

PiperOrigin-RevId: 309789626
Change-Id: I446279fd9215c8abe1ec329f7b9b25f05772e276
Allen Lavoie authored on 2020-05-04 11:49:15 -07:00; committed by TensorFlower Gardener
commit 17ba8a6ed4, parent 9b2dc5ff87
2 changed files with 3 additions and 69 deletions

tensorflow/python/kernel_tests/cwise_ops_binary_test.py

@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import errors
@@ -29,7 +28,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
@@ -761,7 +759,7 @@ class BinaryOpTest(test.TestCase):
           ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
 
   @test_util.run_deprecated_v1
-  def testZeroBasePowGrad(self):
+  def testZeroPowGrad(self):
     with self.cached_session():
       for dtype in (np.float16, np.float32, np.float64, np.complex64,
                     np.complex128):
@@ -771,43 +769,6 @@
         error = gradient_checker.compute_gradient_error(y, [], z, [])
         self.assertEqual(error, 0)
 
-  @test_util.run_in_graph_and_eager_modes
-  def testZeroPowerPowGrad(self):
-    # Tests for 0. ** 0.'s gradient with respect to the base for real dtypes
-    # only. For complex types 0. ** 0. itself isn't well defined, so we'd get a
-    # non-deterministic smattering of NaNs in the test.
-    # pylint: disable=cell-var-from-loop
-    for dtype in (np.float16, np.float32, np.float64):
-      sym_jac, num_jac = gradient_checker_v2.compute_gradient(
-          lambda x: math_ops.pow(x, 0.),
-          [constant_op.constant([-1., 0., 1.], dtype=dtype)])
-      self.assertAllClose(sym_jac, num_jac)
-      power = constant_op.constant([0., 0., 0.], dtype=dtype)
-      sym_jac, num_jac = gradient_checker_v2.compute_gradient(
-          lambda x: math_ops.pow(x, power),
-          [constant_op.constant([-1., 0., 1.], dtype=dtype)])
-      self.assertAllClose(sym_jac, num_jac)
-
-      with backprop.GradientTape() as tape:
-        x = constant_op.constant(float("NaN"), dtype=dtype)
-        tape.watch(x)
-        y = x ** 0.
-      self.assertAllClose(0., tape.gradient(y, x))
-    # pylint: enable=cell-var-from-loop
-
-    x = constant_op.constant(0.)
-    with backprop.GradientTape(persistent=True) as tape:
-      tape.watch(x)
-      y = math_ops.pow(x, 2.)
-      x_g = tape.gradient(y, x)
-      self.assertAllClose(2. * 0. ** 1., x_g)
-      x_gg = tape.gradient(x_g, x)
-      self.assertAllClose(2. * 0. ** 0., x_gg)
-      x_ggg = tape.gradient(x_gg, x)
-    self.assertAllClose(0., x_ggg)
-    # Note that higher-order gradients currently return NaN since backprop
-    # isn't very smart.
-
   @test_util.run_deprecated_v1
   def testComplexPowGrad(self):
     with self.cached_session():

tensorflow/python/ops/math_grad.py

@@ -1338,28 +1338,7 @@ def _PowGrad(op, grad):
         y):
       x = math_ops.conj(x)
       y = math_ops.conj(y)
-      if x.dtype.is_complex:
-        dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
-      else:
-        # Define x ** 0 to have a derivative of zero with respect to x for real
-        # dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0. It is a
-        # more efficient equivalent of this subgraph:
-        #
-        #   raw_derivative = y * tf.pow(x, y - 1)
-        #   dx = tf.where(tf.broadcast_to(tf.not_equal(y, 0),
-        #                                 tf.shape(raw_derivative)),
-        #                 raw_derivative,
-        #                 tf.zeros_like(raw_derivative))
-        #
-        # It does not suppress all NaNs. mul_no_nan only differs from regular
-        # multiplication when `y = 0`, in which case `tf.pow(x, y - 1)` would
-        # ordinarily be NaN if `x = 0` (leading to a NaN gradient). Small
-        # perturbations to `x` would not affect the result of `x ** 0`, meaning
-        # the correct gradient is 0 whenever `y = 0`. This special-casing also
-        # defines a gradient for `x` when `x = NaN` and `y = 0`; this follows
-        # from the convention that `NaN ** 0 = 1`.
-        dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
-      return grad * dx, None
+      return grad * y * math_ops.pow(x, y - 1), None
 
   except AttributeError:
     # No gradient skipping, so do the full gradient computation
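
Editorial aside, not part of the diff: a small sketch of the equivalence described in the comment above, using the public tf.math.multiply_no_nan wrapper in place of gen_math_ops.mul_no_nan.

    import tensorflow as tf

    x = tf.constant([0.0, 2.0])
    y = tf.constant([0.0, 3.0])

    raw_derivative = y * tf.pow(x, y - 1)   # NaN in the first slot (0 * inf)
    where_dx = tf.where(tf.not_equal(y, 0),
                        raw_derivative,
                        tf.zeros_like(raw_derivative))
    no_nan_dx = tf.math.multiply_no_nan(tf.pow(x, y - 1), y)
    # Both evaluate to [0., 12.]: multiply_no_nan returns 0 wherever its second
    # argument is 0, even if the first argument is NaN or infinite.
    print(where_dx.numpy(), no_nan_dx.numpy())
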
@@ -1371,13 +1350,7 @@ def _PowGrad(op, grad):
   y = math_ops.conj(y)
   if skip_input_indices is None or 0 not in skip_input_indices:
-    if x.dtype.is_complex:
-      dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
-    else:
-      # Define x ** 0 to have a derivative of zero with respect to x for real
-      # dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0.
-      dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
-    gx = grad * dx
+    gx = grad * y * math_ops.pow(x, y - 1)
     if must_reduce_x:
       gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
   else: