Fix tf.angle for real tensors. For x < 0, the angle is pi, not zero.
PiperOrigin-RevId: 309471787 Change-Id: If49367d3f6a56bd5fa361d8c23627e631c9feb8f
This commit is contained in:
parent
0f9d5fac64
commit
88a2edd2c7
@ -1053,19 +1053,27 @@ class ComplexMakeRealImagTest(test.TestCase):
|
||||
self.assertAllClose(np_angle, tf_angle_val)
|
||||
self.assertShapeEqual(np_angle, tf_angle)
|
||||
|
||||
def testAngle64(self):
|
||||
real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float32)
|
||||
imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float32)
|
||||
cplx = real + 1j * imag
|
||||
self._compareAngle(cplx, use_gpu=False)
|
||||
self._compareAngle(cplx, use_gpu=True)
|
||||
|
||||
def testAngle(self):
|
||||
real = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(np.float64)
|
||||
imag = (np.arange(-3, 3) / 5.).reshape([1, 3, 2]).astype(np.float64)
|
||||
cplx = real + 1j * imag
|
||||
mag = np.random.rand(10).astype(np.float32)
|
||||
angle = (2 * np.pi * np.arange(10) / 10.).astype(np.float32)
|
||||
cplx = mag * np.exp(1j * angle)
|
||||
cplx = np.append(cplx, [1., 1.j, -1., -1.j])
|
||||
self._compareAngle(cplx, use_gpu=False)
|
||||
self._compareAngle(cplx, use_gpu=True)
|
||||
real = (np.arange(-2, 2) / 2.).astype(np.float64)
|
||||
self._compareAngle(real, use_gpu=False)
|
||||
self._compareAngle(real, use_gpu=True)
|
||||
|
||||
def testAngle64(self):
|
||||
mag = np.random.rand(10).astype(np.float64)
|
||||
angle = (2 * np.pi * np.arange(10) / 100.).astype(np.float64)
|
||||
cplx = mag * np.exp(1j * angle)
|
||||
cplx = np.append(cplx, [1., 1.j, -1., -1.j])
|
||||
self._compareAngle(cplx, use_gpu=False)
|
||||
self._compareAngle(cplx, use_gpu=True)
|
||||
real = (np.arange(-2, 2) / 2.).astype(np.float64)
|
||||
self._compareAngle(real, use_gpu=False)
|
||||
self._compareAngle(real, use_gpu=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRealReal(self):
|
||||
|
@ -804,7 +804,8 @@ def angle(input, name=None):
|
||||
if input.dtype.is_complex:
|
||||
return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
|
||||
else:
|
||||
return array_ops.zeros_like(input)
|
||||
return array_ops.where(input < 0, np.pi * array_ops.ones_like(input),
|
||||
array_ops.zeros_like(input))
|
||||
|
||||
|
||||
# pylint: enable=redefined-outer-name,redefined-builtin
|
||||
|
Loading…
Reference in New Issue
Block a user