Add log1p (#5356)
* Add log1p.
* Use log1p in cross entropy.
* Register log1p in math_ops.
* Update copyright.
* Fix tests.
This commit is contained in:
parent
e1d1167b64
commit
7b7c02de56
@ -58,7 +58,7 @@ class MultiClassTargetColumnTest(tf.test.TestCase):
|
||||
labels = tf.constant([[1.], [0.]])
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
self.assertAlmostEqual(.81326163,
|
||||
self.assertAlmostEqual(0.81326175,
|
||||
sess.run(target_column.loss(logits, labels, {})))
|
||||
|
||||
def testBinaryClassificationWithWeights(self):
|
||||
|
@ -74,7 +74,7 @@ class MultiClassModelHeadTest(tf.test.TestCase):
|
||||
model_fn_ops = head.head_ops({}, labels,
|
||||
tf.contrib.learn.ModeKeys.TRAIN,
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertAlmostEqual(.81326163, sess.run(model_fn_ops.loss))
|
||||
self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss))
|
||||
|
||||
def testErrorInSparseTensorLabels(self):
|
||||
head = head_lib._multi_class_head(n_classes=2)
|
||||
|
26
tensorflow/core/kernels/cwise_op_gpu_log1p.cu.cc
Normal file
26
tensorflow/core/kernels/cwise_op_gpu_log1p.cu.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
// GPU side of the log1p functor for Eigen::half, float and double.
// NOTE(review): DEFINE_UNARY3 is defined outside this file (presumably in the
// included cwise_ops_gpu_common.cu.h); it presumably instantiates the functor
// for the three listed types -- confirm against the macro definition.
DEFINE_UNARY3(log1p, Eigen::half, float, double);
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
24
tensorflow/core/kernels/cwise_op_log1p.cc
Normal file
24
tensorflow/core/kernels/cwise_op_log1p.cc
Normal file
@ -0,0 +1,24 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// CPU registration of the "Log1p" unary op, backed by functor::log1p, for
// float, Eigen::half, double, complex64 and complex128.
REGISTER5(UnaryOp, CPU, "Log1p", functor::log1p, float, Eigen::half, double,
|
||||
complex64, complex128);
|
||||
#if GOOGLE_CUDA
|
||||
// GPU registration is compiled only when GOOGLE_CUDA is set and lists only the
// real floating-point types (no complex64/complex128).
REGISTER3(UnaryOp, GPU, "Log1p", functor::log1p, float, Eigen::half, double);
|
||||
#endif
|
||||
} // namespace tensorflow
|
@ -454,6 +454,9 @@ struct exp : base<T, Eigen::internal::scalar_exp_op<T> > {};
|
||||
template <typename T>
|
||||
struct log : base<T, Eigen::internal::scalar_log_op<T> > {};
|
||||
|
||||
// Element-wise log1p functor: derives from base<T, ...> with Eigen's
// scalar_log1p_op<T> as the scalar operation, mirroring the neighbouring
// log/sign functors.
template <typename T>
|
||||
struct log1p : base<T, Eigen::internal::scalar_log1p_op<T> > {};
|
||||
|
||||
template <typename T>
|
||||
struct sign : base<T, Eigen::internal::scalar_sign_op<T> > {};
|
||||
|
||||
|
@ -296,6 +296,13 @@ Computes natural logarithm of x element-wise.
|
||||
I.e., \\(y = \log_e x\\).
|
||||
)doc");
|
||||
|
||||
// Op definition for "Log1p". The input/output signature comes from the
// UNARY_COMPLEX() macro (defined elsewhere), the same one the "Tanh" op below
// uses; the raw string is the op's user-facing documentation.
REGISTER_OP("Log1p")
|
||||
.UNARY_COMPLEX()
|
||||
.Doc(R"doc(
|
||||
Computes natural logarithm of (1 + x) element-wise.
|
||||
I.e., \\(y = \log_e (1 + x)\\).
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("Tanh")
|
||||
.UNARY_COMPLEX()
|
||||
.Doc(R"doc(
|
||||
|
@ -194,6 +194,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBoth(z, self._rsqrt, tf.rsqrt)
|
||||
self._compareBoth(x, np.exp, tf.exp)
|
||||
self._compareBoth(z, np.log, tf.log)
|
||||
self._compareBoth(z, np.log1p, tf.log1p)
|
||||
self._compareBoth(x, np.tanh, tf.tanh)
|
||||
self._compareBoth(x, self._sigmoid, tf.sigmoid)
|
||||
self._compareBoth(y, np.sign, tf.sign)
|
||||
@ -236,6 +237,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBoth(x, self._rsqrt, tf.rsqrt)
|
||||
self._compareBoth(x, np.exp, tf.exp)
|
||||
self._compareBoth(x, np.log, tf.log)
|
||||
self._compareBoth(x, np.log1p, tf.log1p)
|
||||
self._compareBoth(x, np.tanh, tf.tanh)
|
||||
self._compareBoth(x, self._sigmoid, tf.sigmoid)
|
||||
self._compareBoth(x, np.sign, tf.sign)
|
||||
@ -273,6 +275,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBoth(z, self._rsqrt, tf.rsqrt)
|
||||
self._compareBoth(x, np.exp, tf.exp)
|
||||
self._compareBoth(z, np.log, tf.log)
|
||||
self._compareBoth(z, np.log1p, tf.log1p)
|
||||
self._compareBoth(x, np.tanh, tf.tanh)
|
||||
self._compareBoth(x, self._sigmoid, tf.sigmoid)
|
||||
self._compareBoth(y, np.sign, tf.sign)
|
||||
@ -311,6 +314,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBoth(z, self._rsqrt, tf.rsqrt)
|
||||
self._compareBoth(x, np.exp, tf.exp)
|
||||
self._compareBoth(z, np.log, tf.log)
|
||||
self._compareBoth(z, np.log1p, tf.log1p)
|
||||
self._compareBoth(x, np.tanh, tf.tanh)
|
||||
self._compareBoth(x, self._sigmoid, tf.sigmoid)
|
||||
self._compareBoth(y, np.sign, tf.sign)
|
||||
@ -374,6 +378,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareCpu(y, self._rsqrt, tf.rsqrt)
|
||||
self._compareCpu(x, np.exp, tf.exp)
|
||||
self._compareCpu(y, np.log, tf.log)
|
||||
self._compareCpu(y, np.log1p, tf.log1p)
|
||||
self._compareCpu(x, np.tanh, tf.tanh)
|
||||
self._compareCpu(x, self._sigmoid, tf.sigmoid)
|
||||
self._compareCpu(x, np.sin, tf.sin)
|
||||
@ -405,6 +410,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareCpu(y, self._rsqrt, tf.rsqrt)
|
||||
self._compareCpu(x, np.exp, tf.exp)
|
||||
self._compareCpu(y, np.log, tf.log)
|
||||
self._compareCpu(y, np.log1p, tf.log1p)
|
||||
self._compareCpu(x, np.tanh, tf.tanh)
|
||||
self._compareCpu(x, self._sigmoid, tf.sigmoid)
|
||||
self._compareCpu(x, np.sin, tf.sin)
|
||||
|
@ -326,6 +326,15 @@ def _LogGrad(op, grad):
|
||||
return grad * math_ops.inv(x)
|
||||
|
||||
|
||||
# Gradient function registered for the "Log1p" op.
@ops.RegisterGradient("Log1p")
|
||||
def _Log1pGrad(op, grad):
|
||||
"""Returns grad * (1/(1 + x))."""
|
||||
# x is the first (only visible) input of the forward op.
x = op.inputs[0]
|
||||
# The conjugate is created under a control dependency on grad.op, so it is
# ordered after the incoming gradient has been computed.
with ops.control_dependencies([grad.op]):
|
||||
# NOTE(review): conj presumably leaves real inputs unchanged and only matters
# for complex dtypes -- confirm.
x = math_ops.conj(x)
|
||||
# d/dx log(1 + x) = 1 / (1 + x), scaled by the incoming gradient.
return grad * math_ops.inv(1 + x)
|
||||
|
||||
|
||||
@ops.RegisterGradient("Tanh")
|
||||
def _TanhGrad(op, grad):
|
||||
"""Returns grad * (1 - tanh(x) * tanh(x))."""
|
||||
|
@ -51,6 +51,7 @@ mathematical functions to your graph.
|
||||
@@pow
|
||||
@@exp
|
||||
@@log
|
||||
@@log1p
|
||||
@@ceil
|
||||
@@floor
|
||||
@@maximum
|
||||
@ -1885,6 +1886,7 @@ ops.RegisterShape("IsFinite")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("IsInf")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("IsNan")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("Log")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("Log1p")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("LogicalNot")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("Neg")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("Real")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -454,7 +454,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
|
||||
relu_logits = math_ops.select(cond, logits, zeros)
|
||||
neg_abs_logits = math_ops.select(cond, -logits, logits)
|
||||
return math_ops.add(relu_logits - logits * targets,
|
||||
math_ops.log(1 + math_ops.exp(neg_abs_logits)),
|
||||
math_ops.log1p(math_ops.exp(neg_abs_logits)),
|
||||
name=name)
|
||||
|
||||
|
||||
@ -522,7 +522,7 @@ def weighted_cross_entropy_with_logits(logits, targets, pos_weight, name=None):
|
||||
log_weight = 1 + (pos_weight - 1) * targets
|
||||
return math_ops.add(
|
||||
(1 - targets) * logits,
|
||||
log_weight * (math_ops.log(1 + math_ops.exp(-math_ops.abs(logits))) +
|
||||
log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
|
||||
nn_ops.relu(-logits)),
|
||||
name=name)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user