Fix tf.angle for real tensors. For x < 0, the angle is pi, not zero.

PiperOrigin-RevId: 309471787
Change-Id: If49367d3f6a56bd5fa361d8c23627e631c9feb8f
This commit is contained in:
A. Unique TensorFlower 2020-05-01 14:13:34 -07:00 committed by TensorFlower Gardener
parent 0f9d5fac64
commit 88a2edd2c7
2 changed files with 20 additions and 11 deletions

View File

@ -1053,19 +1053,27 @@ class ComplexMakeRealImagTest(test.TestCase):
self.assertAllClose(np_angle, tf_angle_val)
self.assertShapeEqual(np_angle, tf_angle)
def testAngle64(self):
real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
cplx = real + 1j * imag
self._compareAngle(cplx, use_gpu=False)
self._compareAngle(cplx, use_gpu=True)
def testAngle(self):
real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64)
imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64)
cplx = real + 1j * imag
mag = np.random.rand(10).astype(np.float32)
angle = (2 * np.pi * np.arange(10) / 10.).astype(np.float32)
cplx = mag * np.exp(1j * angle)
cplx = np.append(cplx, [1., 1.j, -1., -1.j])
self._compareAngle(cplx, use_gpu=False)
self._compareAngle(cplx, use_gpu=True)
real = (np.arange(-2, 2) / 2.).astype(np.float64)
self._compareAngle(real, use_gpu=False)
self._compareAngle(real, use_gpu=True)
def testAngle64(self):
mag = np.random.rand(10).astype(np.float64)
angle = (2 * np.pi * np.arange(10) / 100.).astype(np.float64)
cplx = mag * np.exp(1j * angle)
cplx = np.append(cplx, [1., 1.j, -1., -1.j])
self._compareAngle(cplx, use_gpu=False)
self._compareAngle(cplx, use_gpu=True)
real = (np.arange(-2, 2) / 2.).astype(np.float64)
self._compareAngle(real, use_gpu=False)
self._compareAngle(real, use_gpu=True)
@test_util.run_deprecated_v1
def testRealReal(self):

View File

@ -804,7 +804,8 @@ def angle(input, name=None):
if input.dtype.is_complex:
return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
else:
return array_ops.zeros_like(input)
return array_ops.where(input < 0, np.pi * array_ops.ones_like(input),
array_ops.zeros_like(input))
# pylint: enable=redefined-outer-name,redefined-builtin