Fixes bug in accumulation of total-approximate-duality-gap.

Change: 115528686
This commit is contained in:
A. Unique TensorFlower 2016-02-24 22:06:25 -08:00 committed by TensorFlower Gardener
parent 73d557cc88
commit e752109efb

View File

@ -365,19 +365,20 @@ class SdcaSolver : public OpKernel {
std::iota(example_ids.begin(), example_ids.end(), 0);
std::random_device random_device;
std::mt19937 random_generator(random_device());
std::atomic<double> total_approx_duality_gap(
std::numeric_limits<double>::max());
std::atomic<double> total_primal_loss(0);
std::atomic<double> total_dual_loss(0);
// Break when duality gap |P(w) - D(alpha)| is less than
// duality_gap_threshold_
double total_approx_duality_gap = std::numeric_limits<double>::max();
while ((total_approx_duality_gap / weighted_examples) >
duality_gap_threshold_) {
// Reset and add everything.
total_approx_duality_gap = 0;
// Reset accumulated losses.
total_primal_loss = 0;
total_dual_loss = 0;
std::shuffle(example_ids.begin(), example_ids.end(), random_generator);
auto do_update = [&](const int64 begin, const int64 end) {
double approx_duality_gap = 0;
double dual_loss_on_example_subset = 0;
double primal_loss_on_example_subset = 0;
for (int64 offset = begin; offset < end; ++offset) {
// Get example id, label, and weight.
const int64 example_id = example_ids[offset];
@ -400,9 +401,10 @@ class SdcaSolver : public OpKernel {
const double dual_loss = compute_dual_loss_(
current_dual, example_label, example_weight);
dual_loss_on_example_subset += dual_loss;
const double primal_loss = compute_primal_loss_(
per_example_data.wx, example_label, example_weight);
approx_duality_gap += dual_loss + primal_loss;
primal_loss_on_example_subset += primal_loss;
// Update dual variable.
dual_variables(example_id) = compute_dual_update_(
@ -415,10 +417,9 @@ class SdcaSolver : public OpKernel {
UpdateWeights(example_id, sparse_examples_by_index, dense_features,
bounded_dual_delta, regularizations_.symmetric_l2,
&sparse_weights_by_index, &dense_weights_by_index);
AtomicAdd(approx_duality_gap, &total_approx_duality_gap);
AtomicAdd(primal_loss, &total_primal_loss);
}
AtomicAdd(primal_loss_on_example_subset, &total_primal_loss);
AtomicAdd(dual_loss_on_example_subset, &total_dual_loss);
// TODO(rohananil): We may in the future want to make the primal-dual
// relationship consistent as our current updates are not transactional.
};
@ -430,9 +431,9 @@ class SdcaSolver : public OpKernel {
kCostPerUnit, do_update);
const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
sparse_weights_by_index, dense_weights_by_index, regularizations_);
total_approx_duality_gap.store(total_approx_duality_gap.load() +
regularization_loss.l1_loss +
regularization_loss.l2_loss);
total_approx_duality_gap =
total_primal_loss.load() + total_dual_loss.load() +
regularization_loss.l1_loss + regularization_loss.l2_loss;
primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss +
regularization_loss.l2_loss) /
weighted_examples;