Implement a cumulative log-sum-exp operation.

PiperOrigin-RevId: 259315584
A. Unique TensorFlower 2019-07-22 06:32:13 -07:00 committed by TensorFlower Gardener
parent 384e7f8c86
commit 68595d89ce
16 changed files with 396 additions and 6 deletions

View File

@@ -0,0 +1,50 @@
op {
graph_op_name: "CumulativeLogsumexp"
visibility: HIDDEN
in_arg {
name: "x"
description: <<END
A `Tensor`. Must be one of the following types: `float16`, `float32`, `float64`.
END
}
in_arg {
name: "axis"
description: <<END
A `Tensor` of type `int32` (default: 0). Must be in the range
`[-rank(x), rank(x))`.
END
}
attr {
name: "exclusive"
description: <<END
If `True`, perform exclusive cumulative log-sum-exp.
END
}
attr {
name: "reverse"
description: <<END
A `bool` (default: False). If `True`, the cumulative log-sum-exp is performed in
the reverse direction.
END
}
summary: "Compute the cumulative product of the tensor `x` along `axis`."
description: <<END
By default, this op performs an inclusive cumulative log-sum-exp,
which means that the first
element of the input is identical to the first element of the output:
```python
tf.math.cumulative_logsumexp([a, b, c]) # => [a, log(exp(a) + exp(b)), log(exp(a) + exp(b) + exp(c))]
```
By setting the `exclusive` kwarg to `True`, an exclusive cumulative log-sum-exp is
performed instead:
```python
tf.math.cumulative_logsumexp([a, b, c], exclusive=True) # => [-inf, a, log(exp(a) + exp(b))]
```
Note that the neutral element of the log-sum-exp operation is `-inf`;
however, for performance reasons, the minimal value representable by the
floating point type is used instead.
By setting the `reverse` kwarg to `True`, the cumulative log-sum-exp is performed in the
opposite direction.
END
}
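For readers who want to check these semantics outside of TensorFlow, here is a minimal NumPy sketch of the behaviour documented above (illustrative only: it uses a true `-inf` rather than the lowest representable value, and the helper name is hypothetical):
```python
import numpy as np

def cumulative_logsumexp_ref(x, exclusive=False, reverse=False):
    """Reference semantics for the op, using -inf as the true neutral element."""
    x = np.asarray(x, dtype=np.float64)
    if reverse:
        return cumulative_logsumexp_ref(x[::-1], exclusive=exclusive)[::-1]
    if exclusive:
        x = np.concatenate([[-np.inf], x[:-1]])
    return np.logaddexp.accumulate(x)

a, b, c = 0.1, 0.2, 0.3
print(cumulative_logsumexp_ref([a, b, c]))
# [a, log(exp(a) + exp(b)), log(exp(a) + exp(b) + exp(c))]
print(cumulative_logsumexp_ref([a, b, c], exclusive=True))
# [-inf, a, log(exp(a) + exp(b))]
```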

View File

@@ -18,6 +18,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/scan_ops.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -25,10 +28,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/scan_ops.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -107,8 +106,12 @@ namespace functor {
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
DECLARE_FOR_ALL_REDUCERS(int32);
DECLARE_FOR_ALL_REDUCERS(int64);
#undef DECLARE_FOR_ALL_REDUCERS
#define DECLARE_FOR_LOGSUMEXP_REDUCER(T) DECLARE(LogSumExpReducer<T>, T);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_LOGSUMEXP_REDUCER)
#undef DECLARE_FOR_LOGSUMEXP_REDUCER
#undef DECLARE
} // namespace functor
@@ -192,4 +195,31 @@ REGISTER_GPU_KERNELS(int64);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_CUMLOGSUMEXP_KERNEL(device, device_type, type, type_idx) \
REGISTER_KERNEL_BUILDER( \
Name("CumulativeLogsumexp") \
.Device(device) \
.TypeConstraint<type>("T") \
.TypeConstraint<type_idx>("Tidx") \
.HostMemory("axis"), \
ScanOp<device_type, type, functor::LogSumExpReducer<type>, type_idx>)
#define REGISTER_CPU_KERNELS(type) \
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int32) \
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int64)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int32) \
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int64)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_CUMLOGSUMEXP_KERNEL
} // namespace tensorflow
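As a quick sanity check of the kernel registrations above (hedged: assumes a TensorFlow build that contains this commit and eager execution), the public wrapper should accept every registered `T`/`Tidx` combination on whichever device is available:
```python
import tensorflow as tf

x = tf.constant([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
for t in (tf.float16, tf.float32, tf.float64):    # registered "T" types
    for tidx in (tf.int32, tf.int64):             # registered "Tidx" types
        axis = tf.constant(1, dtype=tidx)         # axis is declared HostMemory
        y = tf.math.cumulative_logsumexp(tf.cast(x, t), axis=axis)
        print(t.name, tidx.name, y.dtype, y.shape)
```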

View File

@@ -40,6 +40,41 @@ struct Scan {
}
};
template <typename T>
struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
const T& b) const {
Eigen::internal::scalar_sum_op<T> sum_op;
Eigen::internal::scalar_exp_op<T> exp_op;
Eigen::internal::scalar_log_op<T> log_op;
Eigen::internal::scalar_max_op<T> max_op;
Eigen::internal::scalar_min_op<T> min_op;
Eigen::internal::scalar_log1p_op<T> log1p_op;
Eigen::internal::scalar_difference_op<T> diff_op;
auto mi = min_op(a, b);
auto ma = max_op(a, b);
return sum_op(log1p_op(exp_op(diff_op(mi, ma))), ma);
}
};
template <typename T>
struct LogSumExpReducer {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
LogSumExp<T> logsumexp;
*accum = logsumexp(*accum, t);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return Eigen::NumTraits<T>::lowest();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
return accum;
}
};
} // namespace functor
} // namespace tensorflow
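The `LogSumExp` functor above combines two log-space values as `log1p(exp(min - max)) + max`, which never overflows the intermediate `exp`. A small NumPy sketch of the same combiner (the function name is illustrative):
```python
import numpy as np

def log_add_exp(a, b):
    # Mirrors the functor: log(exp(a) + exp(b)) computed as log1p(exp(min - max)) + max.
    mi, ma = np.minimum(a, b), np.maximum(a, b)
    return np.log1p(np.exp(mi - ma)) + ma

a, b = 1000.0, 1001.0
print(log_add_exp(a, b))              # ~1001.3133, finite
print(np.log(np.exp(a) + np.exp(b)))  # inf: the naive form overflows even in float64
```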

View File

@@ -143,9 +143,16 @@ struct IsProd {
std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
};
template <typename T, typename Op>
struct IsLogSumExp {
constexpr static bool value = (std::is_same<Op, LogSumExp<T>>::value ||
std::is_same<Op, LogSumExpReducer<T>>::value);
};
template <typename T, typename Op>
struct IdentityValue {
static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value,
static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value ||
IsLogSumExp<T, Op>::value,
"IdentityValue not yet defined for this type.");
template <typename U = T, typename OpCopy = Op>
@@ -159,6 +166,13 @@ struct IdentityValue {
typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
return t;
}
template <typename U = T, typename OpCopy = Op>
__host__ __device__ U
operator()(typename std::enable_if<IsLogSumExp<U, OpCopy>::value, U>::type t =
U(Eigen::NumTraits<U>::lowest())) {
return t;
}
};
// Each block is mapped to one sequence. A contiguous range is mapped to the
@@ -311,6 +325,16 @@ struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
}
};
template <typename T>
struct Scan<GPUDevice, LogSumExpReducer<T>, T> {
void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
typename TTypes<T, 3>::Tensor out,
const LogSumExpReducer<T>& reducer, const bool reverse,
const bool exclusive) {
LaunchScan<T, LogSumExp<T>>(d, in, out, LogSumExp<T>(), reverse, exclusive);
}
};
} // namespace functor
} // end namespace tensorflow
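The GPU scan needs an identity element, and `IdentityValue` above returns `Eigen::NumTraits<U>::lowest()` rather than `-inf`, matching the note in the api_def. A hedged NumPy illustration of why that substitution is harmless for finite operands:
```python
import numpy as np

lowest = np.finfo(np.float32).min    # analogue of Eigen::NumTraits<float>::lowest()
x = np.float32(3.5)
print(np.logaddexp(lowest, x))       # 3.5: acts as the identity, since exp(lowest) underflows to 0
print(np.logaddexp(lowest, lowest))  # ~lowest rather than -inf -- the documented trade-off
```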

View File

@@ -26,6 +26,8 @@ template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<double>,
double>;
template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<double>,
double>;
template struct functor::Scan<GpuDevice, functor::LogSumExpReducer<double>,
double>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -26,6 +26,8 @@ template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<float>,
float>;
template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<float>,
float>;
template struct functor::Scan<GpuDevice, functor::LogSumExpReducer<float>,
float>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@@ -26,6 +26,8 @@ template struct functor::Scan<
GpuDevice, Eigen::internal::SumReducer<Eigen::half>, Eigen::half>;
template struct functor::Scan<
GpuDevice, Eigen::internal::ProdReducer<Eigen::half>, Eigen::half>;
template struct functor::Scan<GpuDevice, functor::LogSumExpReducer<Eigen::half>,
Eigen::half>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -1597,6 +1597,16 @@ REGISTER_OP("Cumprod")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("CumulativeLogsumexp")
.Input("x : T")
.Input("axis: Tidx")
.Attr("exclusive: bool = false")
.Attr("reverse: bool = false")
.Output("out: T")
.Attr("T: {float16, float32, float64}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnchangedShape);
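Because the registration uses `shape_inference::UnchangedShape`, the output always has the shape of `x`; a small usage sketch (assumes eager execution):
```python
import tensorflow as tf

x = tf.random.normal([3, 4])
y = tf.math.cumulative_logsumexp(x, axis=1)
assert y.shape == x.shape  # UnchangedShape: output shape equals input shape
```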
REGISTER_OP("QuantizedMatMul")
.Input("a: T1")
.Input("b: T2")

View File

@@ -316,6 +316,22 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "cumulative_logsumexp_test",
size = "medium",
srcs = ["cumulative_logsumexp_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:map_fn",
"//tensorflow/python:array_ops",
],
xla_enable_strict_auto_jit = True,
)
tf_py_test(
name = "decode_csv_op_test",
size = "small",

View File

@@ -0,0 +1,114 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for cumulative_logsumexp op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class CumulativeLogsumexpTest(test.TestCase):
valid_dtypes = [dtypes.float32, dtypes.float64]
def _computeLogSumExp(self, x, **kwargs):
result_naive = math_ops.cumsum(math_ops.exp(x), **kwargs)
result_fused = math_ops.exp(math_ops.cumulative_logsumexp(x, **kwargs))
return result_naive, result_fused
def _testLogSumExp(self, x, dtype=dtypes.float32, use_gpu=False, **kwargs):
with self.cached_session(use_gpu=use_gpu):
x = ops.convert_to_tensor(x, dtype=dtype)
result_naive, result_fused = self.evaluate(
self._computeLogSumExp(x, **kwargs))
self.assertAllClose(result_naive, result_fused)
def _testLogSumExpAllArgs(self, x, axis=0, use_gpu=False):
for dtype in self.valid_dtypes:
for reverse in (True, False):
for exclusive in (True, False):
self._testLogSumExp(
x, dtype=dtype, use_gpu=use_gpu,
reverse=reverse, exclusive=exclusive,
axis=axis)
def test1D(self):
x = np.arange(10) / 10.0 - 0.5
self._testLogSumExpAllArgs(x, use_gpu=False)
self._testLogSumExpAllArgs(x, use_gpu=True)
def test2D(self):
x = np.reshape(np.arange(20) / 20.0 - 0.5, (2, 10))
for axis in (-2, -1, 0, 1):
self._testLogSumExpAllArgs(x, axis=axis, use_gpu=False)
self._testLogSumExpAllArgs(x, axis=axis, use_gpu=True)
def _testGradient(self, x, use_gpu=False, **kwargs):
with self.cached_session(use_gpu=use_gpu):
x = ops.convert_to_tensor(x, dtype=dtypes.float64)
grad_naive_theoretical, _ = gradient_checker_v2.compute_gradient(
lambda y: math_ops.cumsum(math_ops.exp(y), **kwargs), [x])
grad_fused_theoretical, _ = gradient_checker_v2.compute_gradient(
lambda y: math_ops.exp(math_ops.cumulative_logsumexp(y, **kwargs)),
[x])
self.assertAllClose(grad_fused_theoretical, grad_naive_theoretical)
def testGradient(self):
for reverse in (True, False):
for exclusive in (True, False):
x = np.arange(10) / 10.0 - 0.5
self._testGradient(x, use_gpu=False,
reverse=reverse, exclusive=exclusive)
self._testGradient(x, use_gpu=True,
reverse=reverse, exclusive=exclusive)
def _logSumExpMap(self, x):
return map_fn.map_fn(
lambda i: math_ops.reduce_logsumexp(x[:i + 1]),
math_ops.range(array_ops.shape(x)[0]),
dtype=x.dtype)
def test1DLarge(self):
# This test ensures that the operation is correct even when the naive
# implementation would overflow.
x_np = np.arange(20) * 20.0
for use_gpu in (True, False):
with self.cached_session(use_gpu=use_gpu):
x_tf = ops.convert_to_tensor(x_np, dtype=dtypes.float32)
result_fused = self.evaluate(math_ops.cumulative_logsumexp(x_tf))
result_map = self.evaluate(self._logSumExpMap(x_tf))
self.assertAllClose(result_fused, result_map)
if __name__ == '__main__':
test.main()

View File

@@ -1641,6 +1641,40 @@ def _CumprodGrad(op, grad):
return [out / x, None]
@ops.RegisterGradient("CumulativeLogsumexp")
def _CumulativeLogsumexpGrad(op, grad):
x = op.inputs[0]
axis = op.inputs[1]
cumulative_logsumexp = op.outputs[0]
exclusive = op.get_attr("exclusive")
reverse = op.get_attr("reverse")
# Split the incoming gradient into positive and negative parts
# in order to take logs. This is required for stable results.
log_grad_positive = array_ops.where_v2(
math_ops.greater(grad, 0),
math_ops.log(grad),
grad.dtype.min)
log_grad_negative = array_ops.where_v2(
math_ops.less(grad, 0),
math_ops.log(-grad),
grad.dtype.min)
output_pos = math_ops.exp(
math_ops.cumulative_logsumexp(
log_grad_positive - cumulative_logsumexp,
axis=axis, reverse=not reverse, exclusive=exclusive) + x)
output_neg = math_ops.exp(
math_ops.cumulative_logsumexp(
log_grad_negative - cumulative_logsumexp,
axis=axis, reverse=not reverse, exclusive=exclusive) + x)
return [output_pos - output_neg, None]
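To see why this works, note that for the inclusive, non-reversed case `y_i = log(sum_{j<=i} exp(x_j))`, so `dy_i/dx_j = exp(x_j - y_i)` for `j <= i`, and the vector-Jacobian product is a reversed cumulative log-sum-exp of `log|grad| - y`, split by the sign of `grad`. A hedged NumPy check of that identity against the direct Jacobian (helper names are illustrative):
```python
import numpy as np

def cum_lse(x):      # reference inclusive, non-reversed forward pass
    return np.logaddexp.accumulate(x)

def rev_cum_lse(t):  # cumulative log-sum-exp in the reverse direction
    return np.logaddexp.accumulate(t[::-1])[::-1]

rng = np.random.RandomState(0)
x = rng.randn(6)
g = rng.randn(6)  # upstream gradient dL/dy
y = cum_lse(x)

# Direct Jacobian: dy_i/dx_j = exp(x_j - y_i) for j <= i, else 0.
jac = np.tril(np.exp(x[None, :] - y[:, None]))
vjp_direct = g @ jac

# Formula used by the registered gradient: split g by sign, take logs,
# and scan in the opposite direction.
lowest = np.finfo(x.dtype).min
log_g_pos = np.where(g > 0, np.log(np.abs(g) + 1e-300), lowest)
log_g_neg = np.where(g < 0, np.log(np.abs(g) + 1e-300), lowest)
vjp_formula = (np.exp(rev_cum_lse(log_g_pos - y) + x)
               - np.exp(rev_cum_lse(log_g_neg - y) + x))

np.testing.assert_allclose(vjp_formula, vjp_direct, rtol=1e-10, atol=1e-12)
```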
@ops.RegisterGradient("NextAfter")
def _NextAfterGrad(op, grad):
"""Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""

View File

@@ -3297,6 +3297,61 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
x, axis, exclusive=exclusive, reverse=reverse, name=name)
@tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"])
def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative log-sum-exp of the tensor `x` along `axis`.
By default, this op performs an inclusive cumulative log-sum-exp, which means
that the first element of the input is identical to the first element of
the output.
This operation is significantly more numerically stable than the equivalent
TensorFlow expression `tf.math.log(tf.math.cumsum(tf.math.exp(x)))`, although
it computes the same result given infinite numerical precision. However, note
that in some cases, it may be less stable than `tf.math.reduce_logsumexp`
for a given element, as it applies the "log-sum-exp trick" in a different
way.
More precisely, where `tf.math.reduce_logsumexp` uses the following trick:
```
log(sum(exp(x))) == log(sum(exp(x - max(x)))) + max(x)
```
it cannot be directly used here as there is no fast way of applying it
to each prefix `x[:i]`. Instead, this function implements a prefix
scan using pairwise log-add-exp, which is a commutative and associative
(up to floating point precision) operator:
```
log_add_exp(x, y) = log(exp(x) + exp(y))
= log(1 + exp(min(x, y) - max(x, y))) + max(x, y)
```
However, reducing with the above operator leads to a different computation
tree (logs are taken repeatedly instead of only at the end), and the maximum
is only computed pairwise instead of over the entire prefix. In general, this
leads to a different and slightly less precise computation.
Args:
x: A `Tensor`. Must be one of the following types: `float16`, `float32`,
`float64`.
axis: A `Tensor` of type `int32` or `int64` (default: 0). Must be in the
range `[-rank(x), rank(x))`.
exclusive: If `True`, perform exclusive cumulative log-sum-exp.
reverse: If `True`, performs the cumulative log-sum-exp in the reverse
direction.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same shape and type as `x`.
"""
with ops.name_scope(name, "CumulativeLogsumexp", [x]) as name:
x = ops.convert_to_tensor(x, name="x")
return gen_math_ops.cumulative_logsumexp(
x, axis, exclusive=exclusive, reverse=reverse, name=name)
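A short usage sketch of the docstring's stability claim (hedged: assumes TensorFlow 2.x-style eager execution), mirroring the `test1DLarge` case in the test file above:
```python
import tensorflow as tf

x = tf.range(20, dtype=tf.float32) * 20.0            # values up to 380
naive = tf.math.log(tf.math.cumsum(tf.math.exp(x)))  # exp(380) overflows float32
fused = tf.math.cumulative_logsumexp(x)
print(naive[-1].numpy())  # inf
print(fused[-1].numpy())  # ~380.0
```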
@tf_export("math.conj", v1=["math.conj", "conj"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("conj")

View File

@@ -112,6 +112,10 @@ tf_module {
name: "cumsum"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
}
member_method {
name: "cumulative_logsumexp"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
}
member_method {
name: "digamma"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -844,6 +844,10 @@ tf_module {
name: "Cumsum"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "CumulativeLogsumexp"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "DataFormatDimMap"
argspec: "args=[\'x\', \'src_format\', \'dst_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'NCHW\', \'None\'], "

View File

@@ -112,6 +112,10 @@ tf_module {
name: "cumsum"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
}
member_method {
name: "cumulative_logsumexp"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
}
member_method {
name: "digamma"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -844,6 +844,10 @@ tf_module {
name: "Cumsum"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "CumulativeLogsumexp"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
}
member_method {
name: "DataFormatDimMap"
argspec: "args=[\'x\', \'src_format\', \'dst_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'NCHW\', \'None\'], "