Adding reduce_logsumexp to core math_ops.py.

This function computes log(sum(exp(elements across dimensions of a tensor))), avoiding underflow and overflow in most cases. Change: 131276538
This commit is contained in:
parent 55d8c20513
commit e1a89ef3f1
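The stability claim rests on the standard shift-by-max identity, log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x)))), which the new op applies below. A minimal NumPy sketch of the identity (an illustration only, not code from this change):

```python
import numpy as np

def logsumexp(x):
    # Shift by the max so the largest exponent is exp(0) == 1, which
    # cannot overflow; the identity holds for any constant shift.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

x = np.array([1000.0, 1001.0, 1002.0, 1003.0])
# np.log(np.sum(np.exp(x))) would overflow to inf; the shifted form is finite.
print(logsumexp(x))  # ~1003.44
```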
@@ -150,6 +150,7 @@ common math computations that reduce various dimensions of a tensor.
@@reduce_mean
@@reduce_all
@@reduce_any
@@reduce_logsumexp

@@accumulate_n
@@ -1250,6 +1251,56 @@ def reduce_any(input_tensor, reduction_indices=None, keep_dims=False,
                      keep_dims, name=name)


def reduce_logsumexp(input_tensor, reduction_indices=None, keep_dims=False,
                     name=None):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
|
||||
|
||||
Reduces `input_tensor` along the dimensions given in `reduction_indices`.
|
||||
Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
|
||||
entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
|
||||
are retained with length 1.
|
||||
|
||||
If `reduction_indices` has no entries, all dimensions are reduced, and a
|
||||
tensor with a single element is returned.
|
||||
|
||||
This funciton is more numerically stable than log(sum(exp(input))). It avoids
|
||||
overflows caused by taking the exp of large inputs and underflows caused by
|
||||
taking the log of small inputs.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# 'x' is [[0, 0, 0]]
|
||||
# [0, 0, 0]]
|
||||
tf.reduce_logsumexp(x) ==> log(6)
|
||||
tf.reduce_logsumexp(x, 0) ==> [log(2), log(2), log(2)]
|
||||
tf.reduce_logsumexp(x, 1) ==> [log(3), log(3)]
|
||||
tf.reduce_logsumexp(x, 1, keep_dims=True) ==> [[log(3)], [log(3)]]
|
||||
tf.reduce_logsumexp(x, [0, 1]) ==> log(6)
|
||||
```
|
||||
|
||||
Args:
|
||||
input_tensor: The tensor to reduce. Should have numeric type.
|
||||
reduction_indices: The dimensions to reduce. If `None` (the defaut),
|
||||
reduces all dimensions.
|
||||
keep_dims: If true, retains reduced dimensions with length 1.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
The reduced tensor.
|
||||
"""
|
||||
  with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
    my_max = array_ops.stop_gradient(
        reduce_max(input_tensor, reduction_indices, keep_dims=True))
    result = gen_math_ops.log(reduce_sum(
        gen_math_ops.exp(input_tensor - my_max),
        reduction_indices,
        keep_dims=True)) + my_max
    if not keep_dims:
      result = array_ops.squeeze(result, reduction_indices)
    return result


def trace(x, name=None):
  """ Compute the trace of a tensor `x`.
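A note on the implementation above: `my_max` is wrapped in `array_ops.stop_gradient`. That is safe because logsumexp is invariant to a constant shift, so treating the max as a constant changes neither the value nor the gradient, which is the softmax of the input. A quick NumPy finite-difference check of that gradient fact (an illustrative sketch, not part of this commit):

```python
import numpy as np

def logsumexp(x):
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

x = np.array([0.5, -1.2, 3.0])
eps = 1e-6

# The finite-difference gradient of logsumexp should match softmax(x).
fd = np.array([(logsumexp(x + eps * np.eye(3)[i]) - logsumexp(x)) / eps
               for i in range(3)])
softmax = np.exp(x - np.max(x)) / np.sum(np.exp(x - np.max(x)))
print(np.allclose(fd, softmax, atol=1e-4))  # True
```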
@@ -47,6 +47,70 @@ class ReduceTest(test_util.TensorFlowTestCase):
      math_ops.reduce_sum(x, axis)

class LogSumExpTest(test_util.TensorFlowTestCase):

  def testReduceLogSumExp(self):
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.random.rand(5, 5).astype(dtype)
      with self.test_session():
        y_tf_np = math_ops.reduce_logsumexp(x_np).eval()
        y_np = log(np.sum(exp(x_np)))
        self.assertAllClose(y_tf_np, y_np)

  def testReductionIndices(self):
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.random.rand(5, 5).astype(dtype)
      with self.test_session():
        y_tf = math_ops.reduce_logsumexp(x_np, reduction_indices=[0])
        y_np = log(np.sum(exp(x_np), axis=0))
        self.assertShapeEqual(y_np, y_tf)
        y_tf_np = y_tf.eval()
        self.assertAllClose(y_tf_np, y_np)

  def testKeepDims(self):
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.random.rand(5, 5).astype(dtype)
      with self.test_session():
        y_tf_np = math_ops.reduce_logsumexp(x_np, keep_dims=True).eval()
        self.assertEqual(y_tf_np.ndim, x_np.ndim)
        y_np = log(np.sum(exp(x_np), keepdims=True))
        self.assertAllClose(y_tf_np, y_np)

  def testOverflow(self):
    x = [1000, 1001, 1002, 1003]
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.array(x, dtype=dtype)
      max_np = np.max(x_np)
      with self.assertRaisesRegexp(RuntimeWarning,
                                   "overflow encountered in exp"):
        out = log(np.sum(exp(x_np)))
        if out == np.inf:
          # numpy may only warn on overflow instead of raising; re-raise so
          # assertRaisesRegexp sees the naive computation fail either way.
          raise RuntimeWarning("overflow encountered in exp")

      with self.test_session():
        x_tf = constant_op.constant(x_np, shape=x_np.shape)
        y_tf_np = math_ops.reduce_logsumexp(x_tf).eval()
        y_np = log(np.sum(exp(x_np - max_np))) + max_np
        self.assertAllClose(y_tf_np, y_np)

  def testUnderflow(self):
    x = [-1000, -1001, -1002, -1003]
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.array(x, dtype=dtype)
      max_np = np.max(x_np)
      with self.assertRaisesRegexp(RuntimeWarning,
                                   "divide by zero encountered in log"):
        out = log(np.sum(exp(x_np)))
        if out == -np.inf:
          # exp underflows to 0 and log(0) yields -inf; numpy may only warn,
          # so re-raise to make the naive computation's failure explicit.
          raise RuntimeWarning("divide by zero encountered in log")

      with self.test_session():
        x_tf = constant_op.constant(x_np, shape=x_np.shape)
        y_tf_np = math_ops.reduce_logsumexp(x_tf).eval()
        y_np = log(np.sum(exp(x_np - max_np))) + max_np
        self.assertAllClose(y_tf_np, y_np)


class RoundTest(test_util.TensorFlowTestCase):

  def testRounding(self):