60 lines
2.5 KiB
C++
60 lines
2.5 KiB
C++
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
#ifndef TENSORFLOW_CORE_KERNELS_LOSS_H_
|
|
#define TENSORFLOW_CORE_KERNELS_LOSS_H_
|
|
|
|
#include "tensorflow/core/lib/core/status.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
class DualLossUpdater {
|
|
public:
|
|
virtual ~DualLossUpdater() {}
|
|
|
|
// Compute update dual (alpha), based on a single example. Various strategies
|
|
// can be employed here, like newton step and/or line search or approximate
|
|
// step that decreases the dual sub-optimality.
|
|
virtual double ComputeUpdatedDual(
|
|
const int num_loss_partitions, const double label,
|
|
const double example_weight, const double current_dual, const double wx,
|
|
const double weighted_example_norm) const = 0;
|
|
|
|
// Compute dual loss based on the current dual (alpha), example label (y)
|
|
// and example weight (cost).
|
|
virtual double ComputeDualLoss(const double current_dual,
|
|
const double example_label,
|
|
const double example_weight) const = 0;
|
|
|
|
// Compute the primal loss based on current estimate of log-odds(wx),
|
|
// example label (y) and example weight (cost).
|
|
virtual double ComputePrimalLoss(const double wx, const double example_label,
|
|
const double example_weight) const = 0;
|
|
|
|
// Primal loss derivative used to compute the dual residue in AdaSDCA
|
|
virtual double PrimalLossDerivative(const double wx,
|
|
const double example_label,
|
|
const double example_weight) const = 0;
|
|
|
|
// This is gamma such that the loss derivative is 1/gamma Lipschitz
|
|
virtual double SmoothnessConstant() const = 0;
|
|
|
|
// Converts binary example labels from 0.0 or 1.0 to appropriate range for
|
|
// each loss function.
|
|
virtual Status ConvertLabel(float* const example_label) const = 0;
|
|
};
|
|
|
|
} // namespace tensorflow
|
|
#endif // TENSORFLOW_CORE_KERNELS_LOSS_H_
|