Implement a cumulative log-sum-exp operation.
PiperOrigin-RevId: 259315584
This commit is contained in:
parent
384e7f8c86
commit
68595d89ce
@ -0,0 +1,50 @@
|
||||
op {
|
||||
graph_op_name: "CumulativeLogsumexp"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "x"
|
||||
description: <<END
|
||||
A `Tensor`. Must be one of the following types: `float16`, `float32`, `float64`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "axis"
|
||||
description: <<END
|
||||
A `Tensor` of type `int32` (default: 0). Must be in the range
|
||||
`[-rank(x), rank(x))`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "exclusive"
|
||||
description: <<END
|
||||
If `True`, perform exclusive cumulative log-sum-exp.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "reverse"
|
||||
description: <<END
|
||||
A `bool` (default: False).
|
||||
END
|
||||
}
|
||||
summary: "Compute the cumulative product of the tensor `x` along `axis`."
|
||||
description: <<END
|
||||
By default, this op performs an inclusive cumulative log-sum-exp,
|
||||
which means that the first
|
||||
element of the input is identical to the first element of the output:
|
||||
```python
|
||||
tf.math.cumulative_logsumexp([a, b, c]) # => [a, log(exp(a) + exp(b)), log(exp(a) + exp(b) + exp(c))]
|
||||
```
|
||||
|
||||
By setting the `exclusive` kwarg to `True`, an exclusive cumulative log-sum-exp is
|
||||
performed instead:
|
||||
```python
|
||||
tf.cumulative_logsumexp([a, b, c], exclusive=True) # => [-inf, a, log(exp(a) * exp(b))]
|
||||
```
|
||||
Note that the neutral element of the log-sum-exp operation is `-inf`,
|
||||
however, for performance reasons, the minimal value representable by the
|
||||
floating point type is used instead.
|
||||
|
||||
By setting the `reverse` kwarg to `True`, the cumulative log-sum-exp is performed in the
|
||||
opposite direction.
|
||||
END
|
||||
}
|
@ -18,6 +18,9 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "tensorflow/core/kernels/scan_ops.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -25,10 +28,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/kernels/scan_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -107,8 +106,12 @@ namespace functor {
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
|
||||
DECLARE_FOR_ALL_REDUCERS(int32);
|
||||
DECLARE_FOR_ALL_REDUCERS(int64);
|
||||
|
||||
#undef DECLARE_FOR_ALL_REDUCERS
|
||||
|
||||
#define DECLARE_FOR_LOGSUMEXP_REDUCER(T) DECLARE(LogSumExpReducer<T>, T);
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_LOGSUMEXP_REDUCER)
|
||||
#undef DECLARE_FOR_LOGSUMEXP_REDUCER
|
||||
|
||||
#undef DECLARE
|
||||
|
||||
} // namespace functor
|
||||
@ -192,4 +195,31 @@ REGISTER_GPU_KERNELS(int64);
|
||||
#undef REGISTER_GPU_KERNELS
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define REGISTER_CUMLOGSUMEXP_KERNEL(device, device_type, type, type_idx) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("CumulativeLogsumexp") \
|
||||
.Device(device) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<type_idx>("Tidx") \
|
||||
.HostMemory("axis"), \
|
||||
ScanOp<device_type, type, functor::LogSumExpReducer<type>, type_idx>)
|
||||
|
||||
#define REGISTER_CPU_KERNELS(type) \
|
||||
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int32) \
|
||||
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int64)
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_CPU_KERNELS);
|
||||
#undef REGISTER_CPU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define REGISTER_GPU_KERNELS(type) \
|
||||
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int32) \
|
||||
REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int64)
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
||||
#undef REGISTER_GPU_KERNELS
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER_CUMLOGSUMEXP_KERNEL
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -40,6 +40,41 @@ struct Scan {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LogSumExp {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
|
||||
const T& b) const {
|
||||
Eigen::internal::scalar_sum_op<T> sum_op;
|
||||
Eigen::internal::scalar_exp_op<T> exp_op;
|
||||
Eigen::internal::scalar_log_op<T> log_op;
|
||||
Eigen::internal::scalar_max_op<T> max_op;
|
||||
Eigen::internal::scalar_min_op<T> min_op;
|
||||
Eigen::internal::scalar_log1p_op<T> log1p_op;
|
||||
Eigen::internal::scalar_difference_op<T> diff_op;
|
||||
|
||||
auto mi = min_op(a, b);
|
||||
auto ma = max_op(a, b);
|
||||
|
||||
return sum_op(log1p_op(exp_op(diff_op(mi, ma))), ma);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LogSumExpReducer {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
|
||||
LogSumExp<T> logsumexp;
|
||||
*accum = logsumexp(*accum, t);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
|
||||
return Eigen::NumTraits<T>::lowest();
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
|
||||
return accum;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -143,9 +143,16 @@ struct IsProd {
|
||||
std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
struct IsLogSumExp {
|
||||
constexpr static bool value = (std::is_same<Op, LogSumExp<T>>::value ||
|
||||
std::is_same<Op, LogSumExpReducer<T>>::value);
|
||||
};
|
||||
|
||||
template <typename T, typename Op>
|
||||
struct IdentityValue {
|
||||
static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value,
|
||||
static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value ||
|
||||
IsLogSumExp<T, Op>::value,
|
||||
"IdentityValue not yet defined for this type.");
|
||||
|
||||
template <typename U = T, typename OpCopy = Op>
|
||||
@ -159,6 +166,13 @@ struct IdentityValue {
|
||||
typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename U = T, typename OpCopy = Op>
|
||||
__host__ __device__ U
|
||||
operator()(typename std::enable_if<IsLogSumExp<U, OpCopy>::value, U>::type t =
|
||||
U(Eigen::NumTraits<U>::lowest())) {
|
||||
return t;
|
||||
}
|
||||
};
|
||||
|
||||
// Each block is mapped to one sequence. A contiguous range is mapped to the
|
||||
@ -311,6 +325,16 @@ struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Scan<GPUDevice, LogSumExpReducer<T>, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
|
||||
typename TTypes<T, 3>::Tensor out,
|
||||
const LogSumExpReducer<T>& reducer, const bool reverse,
|
||||
const bool exclusive) {
|
||||
LaunchScan<T, LogSumExp<T>>(d, in, out, LogSumExp<T>(), reverse, exclusive);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -26,6 +26,8 @@ template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<double>,
|
||||
double>;
|
||||
template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<double>,
|
||||
double>;
|
||||
template struct functor::Scan<GpuDevice, functor::LogSumExpReducer<double>,
|
||||
double>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -26,6 +26,8 @@ template struct functor::Scan<GpuDevice, Eigen::internal::SumReducer<float>,
|
||||
float>;
|
||||
template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<float>,
|
||||
float>;
|
||||
template struct functor::Scan<GpuDevice, functor::LogSumExpReducer<float>,
|
||||
float>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -26,6 +26,8 @@ template struct functor::Scan<
|
||||
GpuDevice, Eigen::internal::SumReducer<Eigen::half>, Eigen::half>;
|
||||
template struct functor::Scan<
|
||||
GpuDevice, Eigen::internal::ProdReducer<Eigen::half>, Eigen::half>;
|
||||
template struct functor::Scan<GpuDevice, functor::LogSumExpReducer<Eigen::half>,
|
||||
Eigen::half>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -1597,6 +1597,16 @@ REGISTER_OP("Cumprod")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("CumulativeLogsumexp")
|
||||
.Input("x : T")
|
||||
.Input("axis: Tidx")
|
||||
.Attr("exclusive: bool = false")
|
||||
.Attr("reverse: bool = false")
|
||||
.Output("out: T")
|
||||
.Attr("T: {float16, float32, float64}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("QuantizedMatMul")
|
||||
.Input("a: T1")
|
||||
.Input("b: T2")
|
||||
|
@ -316,6 +316,22 @@ cuda_py_test(
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "cumulative_logsumexp_test",
|
||||
size = "medium",
|
||||
srcs = ["cumulative_logsumexp_test.py"],
|
||||
additional_deps = [
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:map_fn",
|
||||
"//tensorflow/python:array_ops",
|
||||
],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "decode_csv_op_test",
|
||||
size = "small",
|
||||
|
114
tensorflow/python/kernel_tests/cumulative_logsumexp_test.py
Normal file
114
tensorflow/python/kernel_tests/cumulative_logsumexp_test.py
Normal file
@ -0,0 +1,114 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Functional tests for cumulative_logsumexp op."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CumulativeLogsumexpTest(test.TestCase):
|
||||
valid_dtypes = [dtypes.float32, dtypes.float64]
|
||||
|
||||
def _computeLogSumExp(self, x, **kwargs):
|
||||
result_naive = math_ops.cumsum(math_ops.exp(x), **kwargs)
|
||||
result_fused = math_ops.exp(math_ops.cumulative_logsumexp(x, **kwargs))
|
||||
return result_naive, result_fused
|
||||
|
||||
def _testLogSumExp(self, x, dtype=dtypes.float32, use_gpu=False, **kwargs):
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
x = ops.convert_to_tensor(x, dtype=dtype)
|
||||
|
||||
result_naive, result_fused = self.evaluate(
|
||||
self._computeLogSumExp(x, **kwargs))
|
||||
|
||||
self.assertAllClose(result_naive, result_fused)
|
||||
|
||||
def _testLogSumExpAllArgs(self, x, axis=0, use_gpu=False):
|
||||
for dtype in self.valid_dtypes:
|
||||
for reverse in (True, False):
|
||||
for exclusive in (True, False):
|
||||
self._testLogSumExp(
|
||||
x, dtype=dtype, use_gpu=use_gpu,
|
||||
reverse=reverse, exclusive=exclusive,
|
||||
axis=axis)
|
||||
|
||||
def test1D(self):
|
||||
x = np.arange(10) / 10.0 - 0.5
|
||||
self._testLogSumExpAllArgs(x, use_gpu=False)
|
||||
self._testLogSumExpAllArgs(x, use_gpu=True)
|
||||
|
||||
def test2D(self):
|
||||
x = np.reshape(np.arange(20) / 20.0 - 0.5, (2, 10))
|
||||
|
||||
for axis in (-2, -1, 0, 1):
|
||||
self._testLogSumExpAllArgs(x, axis=axis, use_gpu=False)
|
||||
self._testLogSumExpAllArgs(x, axis=axis, use_gpu=True)
|
||||
|
||||
def _testGradient(self, x, use_gpu=False, **kwargs):
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
x = ops.convert_to_tensor(x, dtype=dtypes.float64)
|
||||
|
||||
grad_naive_theoretical, _ = gradient_checker_v2.compute_gradient(
|
||||
lambda y: math_ops.cumsum(math_ops.exp(y), **kwargs), [x])
|
||||
grad_fused_theoretical, _ = gradient_checker_v2.compute_gradient(
|
||||
lambda y: math_ops.exp(math_ops.cumulative_logsumexp(y, **kwargs)),
|
||||
[x])
|
||||
|
||||
self.assertAllClose(grad_fused_theoretical, grad_naive_theoretical)
|
||||
|
||||
def testGradient(self):
|
||||
for reverse in (True, False):
|
||||
for exclusive in (True, False):
|
||||
x = np.arange(10) / 10.0 - 0.5
|
||||
|
||||
self._testGradient(x, use_gpu=False,
|
||||
reverse=reverse, exclusive=exclusive)
|
||||
self._testGradient(x, use_gpu=True,
|
||||
reverse=reverse, exclusive=exclusive)
|
||||
|
||||
def _logSumExpMap(self, x):
|
||||
return map_fn.map_fn(
|
||||
lambda i: math_ops.reduce_logsumexp(x[:i + 1]),
|
||||
math_ops.range(array_ops.shape(x)[0]),
|
||||
dtype=x.dtype)
|
||||
|
||||
def test1DLarge(self):
|
||||
# This test ensures that the operation is correct even when the naive
|
||||
# implementation would overflow.
|
||||
x_np = np.arange(20) * 20.0
|
||||
|
||||
for use_gpu in (True, False):
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
x_tf = ops.convert_to_tensor(x_np, dtype=dtypes.float32)
|
||||
|
||||
result_fused = self.evaluate(math_ops.cumulative_logsumexp(x_tf))
|
||||
result_map = self.evaluate(self._logSumExpMap(x_tf))
|
||||
|
||||
self.assertAllClose(result_fused, result_map)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1641,6 +1641,40 @@ def _CumprodGrad(op, grad):
|
||||
return [out / x, None]
|
||||
|
||||
|
||||
@ops.RegisterGradient("CumulativeLogsumexp")
|
||||
def _CumulativeLogsumexpGrad(op, grad):
|
||||
x = op.inputs[0]
|
||||
axis = op.inputs[1]
|
||||
cumulative_logsumexp = op.outputs[0]
|
||||
|
||||
exclusive = op.get_attr("exclusive")
|
||||
reverse = op.get_attr("reverse")
|
||||
|
||||
# Split the incoming gradient into positive and negative part
|
||||
# in order to take logs. This is required for stable results.
|
||||
log_grad_positive = array_ops.where_v2(
|
||||
math_ops.greater(grad, 0),
|
||||
math_ops.log(grad),
|
||||
grad.dtype.min)
|
||||
|
||||
log_grad_negative = array_ops.where_v2(
|
||||
math_ops.less(grad, 0),
|
||||
math_ops.log(-grad),
|
||||
grad.dtype.min)
|
||||
|
||||
output_pos = math_ops.exp(
|
||||
math_ops.cumulative_logsumexp(
|
||||
log_grad_positive - cumulative_logsumexp,
|
||||
axis=axis, reverse=not reverse, exclusive=exclusive) + x)
|
||||
|
||||
output_neg = math_ops.exp(
|
||||
math_ops.cumulative_logsumexp(
|
||||
log_grad_negative - cumulative_logsumexp,
|
||||
axis=axis, reverse=not reverse, exclusive=exclusive) + x)
|
||||
|
||||
return [output_pos - output_neg, None]
|
||||
|
||||
|
||||
@ops.RegisterGradient("NextAfter")
|
||||
def _NextAfterGrad(op, grad):
|
||||
"""Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""
|
||||
|
@ -3297,6 +3297,61 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
|
||||
x, axis, exclusive=exclusive, reverse=reverse, name=name)
|
||||
|
||||
|
||||
@tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"])
|
||||
def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None):
|
||||
"""Compute the cumulative log-sum-exp of the tensor `x` along `axis`.
|
||||
|
||||
By default, this op performs an inclusive cumulative log-sum-exp, which means
|
||||
that the first element of the input is identical to the first element of
|
||||
the output.
|
||||
|
||||
This operation is significantly more numerically stable than the equivalent
|
||||
tensorflow operation `tf.math.log(tf.math.cumsum(tf.math.exp(x)))`, although
|
||||
computes the same result given infinite numerical precision. However, note
|
||||
that in some cases, it may be less stable than `tf.math.reduce_logsumexp`
|
||||
for a given element, as it applies the "log-sum-exp trick" in a different
|
||||
way.
|
||||
|
||||
More precisely, where `tf.math.reduce_logsumexp` uses the following trick:
|
||||
|
||||
```
|
||||
log(sum(exp(x))) == log(sum(exp(x - max(x)))) + max(x)
|
||||
```
|
||||
|
||||
it cannot be directly used here as there is no fast way of applying it
|
||||
to each prefix `x[:i]`. Instead, this function implements a prefix
|
||||
scan using pairwise log-add-exp, which is a commutative and associative
|
||||
(up to floating point precision) operator:
|
||||
|
||||
```
|
||||
log_add_exp(x, y) = log(exp(x) + exp(y))
|
||||
= log(1 + exp(min(x, y) - max(x, y))) + max(x, y)
|
||||
```
|
||||
|
||||
However, reducing using the above operator leads to a different computation
|
||||
tree (logs are taken repeatedly instead of only at the end), and the maximum
|
||||
is only computed pairwise instead of over the entire prefix. In general, this
|
||||
leads to a different and slightly less precise computation.
|
||||
|
||||
Args:
|
||||
x: A `Tensor`. Must be one of the following types: `float16`, `float32`,
|
||||
`float64`.
|
||||
axis: A `Tensor` of type `int32` or `int64` (default: 0). Must be in the
|
||||
range `[-rank(x), rank(x))`.
|
||||
exclusive: If `True`, perform exclusive cumulative log-sum-exp.
|
||||
reverse: If `True`, performs the cumulative log-sum-exp in the reverse
|
||||
direction.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor`. Has the same shape and type as `x`.
|
||||
"""
|
||||
with ops.name_scope(name, "CumulativeLogsumexp", [x]) as name:
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
return gen_math_ops.cumulative_logsumexp(
|
||||
x, axis, exclusive=exclusive, reverse=reverse, name=name)
|
||||
|
||||
|
||||
@tf_export("math.conj", v1=["math.conj", "conj"])
|
||||
@dispatch.add_dispatch_support
|
||||
@deprecation.deprecated_endpoints("conj")
|
||||
|
@ -112,6 +112,10 @@ tf_module {
|
||||
name: "cumsum"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "cumulative_logsumexp"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "digamma"
|
||||
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -844,6 +844,10 @@ tf_module {
|
||||
name: "Cumsum"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CumulativeLogsumexp"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DataFormatDimMap"
|
||||
argspec: "args=[\'x\', \'src_format\', \'dst_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'NCHW\', \'None\'], "
|
||||
|
@ -112,6 +112,10 @@ tf_module {
|
||||
name: "cumsum"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "cumulative_logsumexp"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "digamma"
|
||||
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -844,6 +844,10 @@ tf_module {
|
||||
name: "Cumsum"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CumulativeLogsumexp"
|
||||
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DataFormatDimMap"
|
||||
argspec: "args=[\'x\', \'src_format\', \'dst_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'NCHW\', \'None\'], "
|
||||
|
Loading…
x
Reference in New Issue
Block a user