diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
index 2dd733a78b9..4c6a41bf205 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import errors
@@ -29,7 +28,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
@@ -761,7 +759,7 @@ class BinaryOpTest(test.TestCase):
         ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
 
   @test_util.run_deprecated_v1
-  def testZeroBasePowGrad(self):
+  def testZeroPowGrad(self):
     with self.cached_session():
       for dtype in (np.float16, np.float32, np.float64, np.complex64,
                     np.complex128):
@@ -771,43 +769,6 @@ class BinaryOpTest(test.TestCase):
         error = gradient_checker.compute_gradient_error(y, [], z, [])
         self.assertEqual(error, 0)
 
-  @test_util.run_in_graph_and_eager_modes
-  def testZeroPowerPowGrad(self):
-    # Tests for 0. ** 0.'s gradient with respect to the base for real dtypes
-    # only. For complex types 0. ** 0. itself isn't well defined, so we'd get a
-    # non-deterministic smattering of NaNs in the test.
-
-    # pylint: disable=cell-var-from-loop
-    for dtype in (np.float16, np.float32, np.float64):
-      sym_jac, num_jac = gradient_checker_v2.compute_gradient(
-          lambda x: math_ops.pow(x, 0.),
-          [constant_op.constant([-1., 0., 1.], dtype=dtype)])
-      self.assertAllClose(sym_jac, num_jac)
-      power = constant_op.constant([0., 0., 0.], dtype=dtype)
-      sym_jac, num_jac = gradient_checker_v2.compute_gradient(
-          lambda x: math_ops.pow(x, power),
-          [constant_op.constant([-1., 0., 1.], dtype=dtype)])
-      self.assertAllClose(sym_jac, num_jac)
-      with backprop.GradientTape() as tape:
-        x = constant_op.constant(float("NaN"), dtype=dtype)
-        tape.watch(x)
-        y = x ** 0.
-      self.assertAllClose(0., tape.gradient(y, x))
-    # pylint: enable=cell-var-from-loop
-
-    x = constant_op.constant(0.)
-    with backprop.GradientTape(persistent=True) as tape:
-      tape.watch(x)
-      y = math_ops.pow(x, 2.)
-      x_g = tape.gradient(y, x)
-      self.assertAllClose(2. * 0. ** 1., x_g)
-      x_gg = tape.gradient(x_g, x)
-      self.assertAllClose(2. * 0. ** 0., x_gg)
-      x_ggg = tape.gradient(x_gg, x)
-      self.assertAllClose(0., x_ggg)
-    # Note that higher-order gradients currently return NaN since backprop
-    # isn't very smart.
-
   @test_util.run_deprecated_v1
   def testComplexPowGrad(self):
     with self.cached_session():
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 345ee68b995..8ce35de006a 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -1338,28 +1338,7 @@ def _PowGrad(op, grad):
         y):
       x = math_ops.conj(x)
       y = math_ops.conj(y)
-      if x.dtype.is_complex:
-        dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
-      else:
-        # Define x ** 0 to have a derivative of zero with respect to x for real
-        # dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0. It is a
-        # more efficient equivalent of this subgraph:
-        #
-        # raw_derivative = y * tf.pow(x, y - 1)
-        # dx = tf.where(tf.broadcast_to(tf.not_equal(y, 0),
-        #                               tf.shape(raw_derivative)),
-        #               raw_derivative,
-        #               tf.zeros_like(raw_derivative))
-        #
-        # It does not suppress all NaNs. mul_no_nan only differs from regular
-        # multiplication when `y = 0`, in which case `tf.pow(x, y - 1)` would
-        # ordinarily be NaN if `x = 0` (leading to a NaN gradient). Small
-        # perturbations to `x` would not affect the result of `x ** 0`, meaning
-        # the correct gradient is 0 whenever `y = 0`. This special-casing also
-        # defines a gradient for `x` when `x = NaN` and `y = 0`; this follows
-        # from the convention that `NaN ** 0 = 1`.
-        dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
-      return grad * dx, None
+      return grad * y * math_ops.pow(x, y - 1), None
 
   except AttributeError:
     # No gradient skipping, so do the full gradient computation
@@ -1371,13 +1350,7 @@ def _PowGrad(op, grad):
   y = math_ops.conj(y)
 
   if skip_input_indices is None or 0 not in skip_input_indices:
-    if x.dtype.is_complex:
-      dx = gen_math_ops.mul(math_ops.pow(x, y - 1), y)
-    else:
-      # Define x ** 0 to have a derivative of zero with respect to x for real
-      # dtypes. mul_no_nan replaces its outputs with 0 where `y` is 0.
-      dx = gen_math_ops.mul_no_nan(math_ops.pow(x, y - 1), y)
-    gx = grad * dx
+    gx = grad * y * math_ops.pow(x, y - 1)
     if must_reduce_x:
       gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
   else:
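
The hunks above revert the Pow gradient to the plain formula grad * y * pow(x, y - 1). As the deleted comment explains, mul_no_nan only differed from ordinary multiplication where y = 0, which is exactly where pow(x, y - 1) can be infinite or NaN at x = 0. A minimal standalone sketch of the behavioral difference, assuming a TF2 eager environment (illustrative only, not part of the patch):

  import tensorflow as tf

  x = tf.constant(0.0)
  with tf.GradientTape() as tape:
    tape.watch(x)       # x is a plain tensor, not a Variable, so watch it explicitly
    y = tf.pow(x, 0.0)  # 0. ** 0. evaluates to 1.
  # With the mul_no_nan special-casing removed above, dy/dx at x = 0 was defined as 0.;
  # with the plain y * x ** (y - 1) formula it becomes 0. * inf, i.e. NaN.
  print(tape.gradient(y, x))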