Add the Poisson log loss to the SDCA optimizer.

PiperOrigin-RevId: 211116606
This commit is contained in:
A. Unique TensorFlower 2018-08-31 11:30:49 -07:00 committed by TensorFlower Gardener
parent 86ed8fada2
commit e894ca7c73
8 changed files with 282 additions and 2 deletions

View File

@@ -199,6 +199,46 @@ does.
However, in practice, convergence with $$x_0 = 0$$ always happens (tested for a
sample of generic values for the parameters).
### Poisson log loss
Poisson log loss is defined as $$ \l(u) = e^u - uy $$ for label $$y \geq 0.$$
Its convex conjugate is
$$ \l^\star(v) = (y+v) (\log(y+v) - 1) $$
and is only defined for $$ y+v > 0 $$. Since the conjugate is evaluated at
$$ v = -(\a+\d) $$ in the dual objective, this gives the constraint
$$ y > \a+\d. $$
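For completeness, the conjugate follows from
$$ \l^\star(v) = \sup_u \{ uv - e^u + uy \}. $$
The supremum is attained at $$ e^u = y+v $$, i.e. $$ u = \log(y+v) $$, and
substituting back gives $$ (y+v)\log(y+v) - (y+v) = (y+v)(\log(y+v) - 1). $$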
The dual objective is
$$ D(\d) = -(y-\a-\d) (\log(y-\a-\d) - 1) - \bar{y} \d - \frac{A}{2} \d^2 $$
and, since $$ \frac{d}{dg}\left[ g(\log g - 1) \right] = \log g $$, its
derivative is
$$ D'(\d) = \log(y-\a-\d) - \bar{y} - A\d $$
As with the logistic loss, we perform a change of variable to handle the
constraint on $$ \d $$:
$$ y - (\a+\d) = e^x $$
After this change of variable, the goal is to find the zero of the function
$$ H(x) = x - \bar{y} -A(y-\a-e^x) $$
whose first derivative is
$$ H'(x) = 1+Ae^x $$
Since $$H'$$ is always positive, $$H$$ is strictly increasing and has a unique
zero.
We can start the Newton algorithm at $$\d=0$$, which corresponds to $$ x_0 =
\log(y-\a) $$ (the implementation falls back to $$ x_0 = 0 $$ when $$ y-\a $$
is not positive). As before, the Newton step is given by
$$x_{k+1} = x_k - \frac{H(x_k)}{H'(x_k)}. $$
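As an illustration, here is a minimal standalone Python sketch of this update,
using the notation above (`y_bar` for $$\bar{y}$$, `A` for the quadratic
coefficient); it mirrors, but is not, the C++ kernel added below:

```python
import math

def poisson_dual_update(y, a, y_bar, A, steps=10):
  """Finds the zero of H(x) = x - y_bar - A*(y - a - exp(x)) by Newton's
  method and returns the dual update delta = y - a - exp(x)."""
  # Start at delta = 0, i.e. x = log(y - a); fall back to 0 otherwise.
  x = math.log(y - a) if y - a > 0 else 0.0
  for _ in range(steps):
    h = x - y_bar - A * (y - a - math.exp(x))
    h_prime = 1.0 + A * math.exp(x)
    x -= h / h_prime
  return y - a - math.exp(x)
```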
### References
[1] C. Ma et al., Adding vs. Averaging in Distributed Primal-Dual Optimization,

View File

@@ -1192,6 +1192,57 @@ class SdcaWithSmoothHingeLossTest(SdcaModelTest):
self.assertAllClose(0.33, unregularized_loss.eval(), atol=0.02)
self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02)
class SdcaWithPoissonLossTest(SdcaModelTest):
"""SDCA optimizer test class for poisson loss."""
def testSimple(self):
# Setup test data
example_protos = [
make_example_proto({
'age': [0],
'gender': [0]
}, 0),
make_example_proto({
'age': [1],
'gender': [1]
}, 2),
]
example_weights = [100.0, 100.0]
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
options = dict(
symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0,
loss_type='poisson_loss')
model = SdcaModel(examples, variables, options)
variables_lib.global_variables_initializer().run()
# Before minimization, the weights default to zero. There is no loss due
# to regularization, only the unregularized loss, which is 1 for each example.
predictions = model.predictions(examples)
self.assertAllClose([1.0, 1.0], predictions.eval())
unregularized_loss = model.unregularized_loss(examples)
regularized_loss = model.regularized_loss(examples)
approximate_duality_gap = model.approximate_duality_gap()
self.assertAllClose(1.0, unregularized_loss.eval())
self.assertAllClose(1.0, regularized_loss.eval())
# There are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender
# (say w3 and w4). The minimization leads to:
# w1=w3=-1.96487, argmin of 100*(exp(2*w)-2*w*0)+w**2.
# w2=w4=0.345708, argmin of 100*(exp(2*w)-2*w*2)+w**2.
# This gives an unregularized loss of .3167 and .3366 with regularization.
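# Sanity check: d/dw [100*(exp(2w) - 2wy) + w^2] = 200*exp(2w) - 200*y + 2w,
# which vanishes at w = -1.96487 (y=0) and at w = 0.345708 (y=2).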
train_op = model.minimize()
for _ in range(_MAX_ITERATIONS):
train_op.run()
model.update_weights(train_op).run()
self.assertAllClose([0.0196, 1.9965], predictions.eval(), atol=1e-4)
self.assertAllClose(0.3167, unregularized_loss.eval(), atol=1e-4)
self.assertAllClose(0.3366, regularized_loss.eval(), atol=1e-4)
self.assertAllClose(0., approximate_duality_gap.eval(), atol=1e-6)
class SdcaFprintTest(SdcaModelTest):
"""Tests for the SdcaFprint op.

View File

@@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as var_ops
from tensorflow.python.ops.nn import log_poisson_loss
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.summary import summary
@@ -51,6 +52,7 @@ class SdcaModel(object):
* Squared loss
* Hinge loss
* Smooth hinge loss
* Poisson log loss
This class defines an optimizer API to train a linear model.
@@ -112,7 +114,7 @@ class SdcaModel(object):
raise ValueError('examples, variables and options must all be specified.')
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
'smooth_hinge_loss')
'smooth_hinge_loss', 'poisson_loss')
if options['loss_type'] not in supported_losses:
raise ValueError('Unsupported loss_type: ', options['loss_type'])
@@ -315,6 +317,7 @@ class SdcaModel(object):
"""Add operations to compute predictions by the model.
If logistic_loss is being used, predicted probabilities are returned.
If poisson_loss is being used, predictions are exponentiated.
Otherwise, (raw) linear predictions (w*x) are returned.
Args:
@@ -335,6 +338,10 @@ class SdcaModel(object):
# Convert logits to probability for logistic loss predictions.
with name_scope('sdca/logistic_prediction'):
result = math_ops.sigmoid(result)
elif self._options['loss_type'] == 'poisson_loss':
# Exponentiate the prediction for Poisson loss predictions.
with name_scope('sdca/poisson_prediction'):
result = math_ops.exp(result)
return result
def _get_partitioned_update_ops(self,
@@ -624,6 +631,11 @@ class SdcaModel(object):
logits=predictions),
weights)) / math_ops.reduce_sum(weights)
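# With the default compute_full_loss=False, log_poisson_loss computes
# exp(log_input) - log_input * targets, which matches the primal loss
# exp(wx) - wx * y used by the SDCA kernel.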
if self._options['loss_type'] == 'poisson_loss':
return math_ops.reduce_sum(math_ops.multiply(
log_poisson_loss(targets=labels, log_input=predictions),
weights)) / math_ops.reduce_sum(weights)
if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
# hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to
# first convert 0/1 labels into -1/1 labels.

View File

@@ -4196,6 +4196,7 @@ cc_library(
"hinge-loss.h",
"logistic-loss.h",
"loss.h",
"poisson-loss.h",
"smooth-hinge-loss.h",
"squared-loss.h",
],

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -288,5 +289,68 @@ TEST(SmoothHingeLoss, ComputeUpdatedDual) {
0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(PoissonLoss, ComputePrimalLoss) {
PoissonLossUpdater loss_updater;
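// Primal loss = (exp(wx) - wx * label) * example_weight; for example, the
// second case below is exp(10) - 10 * 3 = 22026.47 - 30 ~ 21996.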
EXPECT_NEAR(1.0,
loss_updater.ComputePrimalLoss(0.0 /* wx */, 3.0 /* label */,
1.0 /* example weight */),
1e-3);
EXPECT_NEAR(21996.0,
loss_updater.ComputePrimalLoss(10.0 /* wx */, 3.0 /* label */,
1.0 /* example weight */),
1.0);
EXPECT_NEAR(0.606,
loss_updater.ComputePrimalLoss(-0.5 /* wx */, 0.0 /* label */,
1.0 /* example weight */),
1e-3);
EXPECT_NEAR(6.64,
loss_updater.ComputePrimalLoss(1.2 /* wx */, 0.0 /* label */,
2.0 /* example weight */),
1e-2);
}
TEST(PoissonLoss, ComputeDualLoss) {
PoissonLossUpdater loss_updater;
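// Dual loss = (label - dual) * (log(label - dual) - 1) * example_weight,
// defined only for dual < label; for example, the last case below is
// 3 * 1.5 * (log(1.5) - 1) ~ -2.675.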
// Dual is undefined.
EXPECT_NEAR(
std::numeric_limits<double>::max(),
loss_updater.ComputeDualLoss(1.0 /* current dual */, 0.0 /* label */,
1.0 /* example weight */),
1e-3);
EXPECT_NEAR(
0.0,
loss_updater.ComputeDualLoss(0.0 /* current dual */, 0.0 /* label */,
3.0 /* example weight */),
1e-3);
EXPECT_NEAR(
-0.847,
loss_updater.ComputeDualLoss(1.5 /* current dual */, 2.0 /* label */,
1.0 /* example weight */),
1e-3);
EXPECT_NEAR(
-2.675,
loss_updater.ComputeDualLoss(0.5 /* current dual */, 2.0 /* label */,
3.0 /* example weight */),
1e-3);
}
TEST(PoissonLoss, ConvertLabel) {
PoissonLossUpdater loss_updater;
float example_label = -1.0;
// Negative label should throw an error.
Status status = loss_updater.ConvertLabel(&example_label);
EXPECT_FALSE(status.ok());
}
TEST(PoissonLoss, ComputeUpdatedDual) {
PoissonLossUpdater loss_updater;
TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 2.0 /* label */,
1.0 /* example weight */, 0.5 /* current_dual */,
0.3 /* wx */, 10.0 /* weighted_example_norm */);
TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, 0.0 /* label */,
1.0 /* example weight */, 0.0 /* current_dual */,
-0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
} // namespace
} // namespace tensorflow

View File

@@ -0,0 +1,109 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
#define TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
#include <cmath>
#include "tensorflow/core/kernels/loss.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
class PoissonLossUpdater : public DualLossUpdater {
public:
// Update is found by a Newton algorithm (see readme.md).
double ComputeUpdatedDual(const int num_loss_partitions, const double label,
const double example_weight,
const double current_dual, const double wx,
const double weighted_example_norm) const final {
// The Newton algorithm converges quadratically, so 10 steps are more than
// enough to reach high precision.
static const int newton_total_steps = 10;
// Initialize the Newton iteration at x such that exp(x) = label - current_dual,
// falling back to x = 0 when that quantity is not positive.
const double y_minus_a = label - current_dual;
double x = (y_minus_a > 0) ? log(y_minus_a) : 0;
for (int i = 0; i < newton_total_steps; ++i) {
x = NewtonStep(x, num_loss_partitions, label, wx, example_weight,
weighted_example_norm, current_dual);
}
return label - exp(x);
}
// Dual of the Poisson loss function.
// https://en.wikipedia.org/wiki/Convex_conjugate
double ComputeDualLoss(const double current_dual, const double example_label,
const double example_weight) const final {
// The dual of the Poisson loss is (y-a)*(log(y-a)-1), where a is the dual
// variable. It is defined only for a < y.
const double y_minus_a = example_label - current_dual;
if (y_minus_a == 0.0) {
// (y-a)*(log(y-a)-1) approaches 0 as y-a approaches 0.
return 0.0;
}
if (y_minus_a < 0.0) {
return std::numeric_limits<double>::max();
}
return y_minus_a * (log(y_minus_a) - 1) * example_weight;
}
double ComputePrimalLoss(const double wx, const double example_label,
const double example_weight) const final {
return (exp(wx) - wx * example_label) * example_weight;
}
double PrimalLossDerivative(const double wx, const double label,
const double example_weight) const final {
return (exp(wx) - label) * example_weight;
}
// TODO(chapelle): We need to introduce a maximum_prediction parameter,
// expose that parameter to the user and have this method return
// 1.0/maximum_prediction.
// Setting this to 1 for now; it only impacts the adaptive sampling.
double SmoothnessConstant() const final { return 1; }
Status ConvertLabel(float* const example_label) const final {
if (*example_label < 0.0) {
return errors::InvalidArgument(
"Only non-negative labels can be used with the Poisson log loss. "
"Found example with label: ", *example_label);
}
return Status::OK();
}
private:
// One Newton step (see readme.md).
double NewtonStep(const double x, const int num_loss_partitions,
const double label, const double wx,
const double example_weight,
const double weighted_example_norm,
const double current_dual) const {
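// In the readme's notation, wx plays the role of y_bar and
// A = num_loss_partitions * weighted_example_norm * example_weight,
// so that numerator = H(x) and denominator = H'(x).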
const double expx = exp(x);
const double numerator =
x - wx - num_loss_partitions * weighted_example_norm *
example_weight * (label - current_dual - expx);
const double denominator =
1 + num_loss_partitions * weighted_example_norm * example_weight * expx;
return x - numerator / denominator;
}
};
} // namespace tensorflow
#endif  // TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_

View File

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
#include "tensorflow/core/kernels/loss.h"
#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/sdca_internal.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
@@ -75,6 +76,8 @@ struct ComputeOptions {
loss_updater.reset(new HingeLossUpdater);
} else if (loss_type == "smooth_hinge_loss") {
loss_updater.reset(new SmoothHingeLossUpdater);
} else if (loss_type == "poisson_loss") {
loss_updater.reset(new PoissonLossUpdater);
} else {
OP_REQUIRES(
context, false,

View File

@@ -41,7 +41,7 @@ static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) {
REGISTER_OP("SdcaOptimizer")
.Attr(
"loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss',"
"'smooth_hinge_loss'}")
"'smooth_hinge_loss', 'poisson_loss'}")
.Attr("adaptative : bool=false")
.Attr("num_sparse_features: int >= 0")
.Attr("num_sparse_features_with_values: int >= 0")