diff --git a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc index e0075a34fbc..fbd6090b704 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc +++ b/tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc @@ -365,19 +365,20 @@ class SdcaSolver : public OpKernel { std::iota(example_ids.begin(), example_ids.end(), 0); std::random_device random_device; std::mt19937 random_generator(random_device()); - std::atomic total_approx_duality_gap( - std::numeric_limits::max()); std::atomic total_primal_loss(0); + std::atomic total_dual_loss(0); // Break when duality gap |P(w) - D(alpha)| is less than // duality_gap_threshold_ + double total_approx_duality_gap = std::numeric_limits::max(); while ((total_approx_duality_gap / weighted_examples) > duality_gap_threshold_) { - // Reset and add everything. - total_approx_duality_gap = 0; + // Reset accumulated losses. total_primal_loss = 0; + total_dual_loss = 0; std::shuffle(example_ids.begin(), example_ids.end(), random_generator); auto do_update = [&](const int64 begin, const int64 end) { - double approx_duality_gap = 0; + double dual_loss_on_example_subset = 0; + double primal_loss_on_example_subset = 0; for (int64 offset = begin; offset < end; ++offset) { // Get example id, label, and weight. const int64 example_id = example_ids[offset]; @@ -400,9 +401,10 @@ class SdcaSolver : public OpKernel { const double dual_loss = compute_dual_loss_( current_dual, example_label, example_weight); + dual_loss_on_example_subset += dual_loss; const double primal_loss = compute_primal_loss_( per_example_data.wx, example_label, example_weight); - approx_duality_gap += dual_loss + primal_loss; + primal_loss_on_example_subset += primal_loss; // Update dual variable. dual_variables(example_id) = compute_dual_update_( @@ -415,10 +417,9 @@ class SdcaSolver : public OpKernel { UpdateWeights(example_id, sparse_examples_by_index, dense_features, bounded_dual_delta, regularizations_.symmetric_l2, &sparse_weights_by_index, &dense_weights_by_index); - - AtomicAdd(approx_duality_gap, &total_approx_duality_gap); - AtomicAdd(primal_loss, &total_primal_loss); } + AtomicAdd(primal_loss_on_example_subset, &total_primal_loss); + AtomicAdd(dual_loss_on_example_subset, &total_dual_loss); // TODO(rohananil): We may in the future want to make the primal-dual // relationship consistent as our current updates are not transactional. }; @@ -430,9 +431,9 @@ class SdcaSolver : public OpKernel { kCostPerUnit, do_update); const RegularizationLoss regularization_loss = ComputeRegularizationLoss( sparse_weights_by_index, dense_weights_by_index, regularizations_); - total_approx_duality_gap.store(total_approx_duality_gap.load() + - regularization_loss.l1_loss + - regularization_loss.l2_loss); + total_approx_duality_gap = + total_primal_loss.load() + total_dual_loss.load() + + regularization_loss.l1_loss + regularization_loss.l2_loss; primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss + regularization_loss.l2_loss) / weighted_examples;