Fixes bug in accumulation of total-approximate-duality-gap.
Change: 115528686
This commit is contained in:
parent
73d557cc88
commit
e752109efb
@ -365,19 +365,20 @@ class SdcaSolver : public OpKernel {
|
|||||||
std::iota(example_ids.begin(), example_ids.end(), 0);
|
std::iota(example_ids.begin(), example_ids.end(), 0);
|
||||||
std::random_device random_device;
|
std::random_device random_device;
|
||||||
std::mt19937 random_generator(random_device());
|
std::mt19937 random_generator(random_device());
|
||||||
std::atomic<double> total_approx_duality_gap(
|
|
||||||
std::numeric_limits<double>::max());
|
|
||||||
std::atomic<double> total_primal_loss(0);
|
std::atomic<double> total_primal_loss(0);
|
||||||
|
std::atomic<double> total_dual_loss(0);
|
||||||
// Break when duality gap |P(w) - D(alpha)| is less than
|
// Break when duality gap |P(w) - D(alpha)| is less than
|
||||||
// duality_gap_threshold_
|
// duality_gap_threshold_
|
||||||
|
double total_approx_duality_gap = std::numeric_limits<double>::max();
|
||||||
while ((total_approx_duality_gap / weighted_examples) >
|
while ((total_approx_duality_gap / weighted_examples) >
|
||||||
duality_gap_threshold_) {
|
duality_gap_threshold_) {
|
||||||
// Reset and add everything.
|
// Reset accumulated losses.
|
||||||
total_approx_duality_gap = 0;
|
|
||||||
total_primal_loss = 0;
|
total_primal_loss = 0;
|
||||||
|
total_dual_loss = 0;
|
||||||
std::shuffle(example_ids.begin(), example_ids.end(), random_generator);
|
std::shuffle(example_ids.begin(), example_ids.end(), random_generator);
|
||||||
auto do_update = [&](const int64 begin, const int64 end) {
|
auto do_update = [&](const int64 begin, const int64 end) {
|
||||||
double approx_duality_gap = 0;
|
double dual_loss_on_example_subset = 0;
|
||||||
|
double primal_loss_on_example_subset = 0;
|
||||||
for (int64 offset = begin; offset < end; ++offset) {
|
for (int64 offset = begin; offset < end; ++offset) {
|
||||||
// Get example id, label, and weight.
|
// Get example id, label, and weight.
|
||||||
const int64 example_id = example_ids[offset];
|
const int64 example_id = example_ids[offset];
|
||||||
@ -400,9 +401,10 @@ class SdcaSolver : public OpKernel {
|
|||||||
|
|
||||||
const double dual_loss = compute_dual_loss_(
|
const double dual_loss = compute_dual_loss_(
|
||||||
current_dual, example_label, example_weight);
|
current_dual, example_label, example_weight);
|
||||||
|
dual_loss_on_example_subset += dual_loss;
|
||||||
const double primal_loss = compute_primal_loss_(
|
const double primal_loss = compute_primal_loss_(
|
||||||
per_example_data.wx, example_label, example_weight);
|
per_example_data.wx, example_label, example_weight);
|
||||||
approx_duality_gap += dual_loss + primal_loss;
|
primal_loss_on_example_subset += primal_loss;
|
||||||
|
|
||||||
// Update dual variable.
|
// Update dual variable.
|
||||||
dual_variables(example_id) = compute_dual_update_(
|
dual_variables(example_id) = compute_dual_update_(
|
||||||
@ -415,10 +417,9 @@ class SdcaSolver : public OpKernel {
|
|||||||
UpdateWeights(example_id, sparse_examples_by_index, dense_features,
|
UpdateWeights(example_id, sparse_examples_by_index, dense_features,
|
||||||
bounded_dual_delta, regularizations_.symmetric_l2,
|
bounded_dual_delta, regularizations_.symmetric_l2,
|
||||||
&sparse_weights_by_index, &dense_weights_by_index);
|
&sparse_weights_by_index, &dense_weights_by_index);
|
||||||
|
|
||||||
AtomicAdd(approx_duality_gap, &total_approx_duality_gap);
|
|
||||||
AtomicAdd(primal_loss, &total_primal_loss);
|
|
||||||
}
|
}
|
||||||
|
AtomicAdd(primal_loss_on_example_subset, &total_primal_loss);
|
||||||
|
AtomicAdd(dual_loss_on_example_subset, &total_dual_loss);
|
||||||
// TODO(rohananil): We may in the future want to make the primal-dual
|
// TODO(rohananil): We may in the future want to make the primal-dual
|
||||||
// relationship consistent as our current updates are not transactional.
|
// relationship consistent as our current updates are not transactional.
|
||||||
};
|
};
|
||||||
@ -430,9 +431,9 @@ class SdcaSolver : public OpKernel {
|
|||||||
kCostPerUnit, do_update);
|
kCostPerUnit, do_update);
|
||||||
const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
|
const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
|
||||||
sparse_weights_by_index, dense_weights_by_index, regularizations_);
|
sparse_weights_by_index, dense_weights_by_index, regularizations_);
|
||||||
total_approx_duality_gap.store(total_approx_duality_gap.load() +
|
total_approx_duality_gap =
|
||||||
regularization_loss.l1_loss +
|
total_primal_loss.load() + total_dual_loss.load() +
|
||||||
regularization_loss.l2_loss);
|
regularization_loss.l1_loss + regularization_loss.l2_loss;
|
||||||
primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss +
|
primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss +
|
||||||
regularization_loss.l2_loss) /
|
regularization_loss.l2_loss) /
|
||||||
weighted_examples;
|
weighted_examples;
|
||||||
|
Loading…
Reference in New Issue
Block a user