Merge pull request #31164 from aprimostka:ctc-float64
PiperOrigin-RevId: 264272190
commit fe5cf47131
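This change templatizes the CTC kernels, op registrations, and the util/ctc helpers over the input scalar type, adding float64 (double) support to CTCLoss, CTCGreedyDecoder, and CTCBeamSearchDecoder alongside the existing float32 path. Each op gains an `.Attr("T: {float, double} = DT_FLOAT")`, and the CPU kernels are registered once per type through a REGISTER_CPU macro. Judging from the includes and op names (the per-file diff headers are not preserved in this excerpt), the hunks below cover the decoder kernels, the loss kernel, the op registrations, and the util/ctc headers, tests, and loss calculator, in that order.

The hunks rely on a type-parameterized `kLogZero<T>::val` in place of the old float-only `kLogZero` constant. Its definition lives in ctc_loss_util.h, which is not part of this excerpt; a minimal sketch of what such a constant could look like (an assumption, not the actual definition):

```cpp
#include <limits>

// Hypothetical sketch: log(0) as a templated constant, so the float and
// double instantiations each get an infinity of the right width.
template <typename T>
struct kLogZero {
  static constexpr T val = -std::numeric_limits<T>::infinity();
};
```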
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
@@ -33,11 +34,12 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
-                    int* c) {
+template <typename T>
+inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
+                int* c) {
   *c = 0;
   CHECK_LT(0, m.dimension(1));
-  float p = m(r, 0);
+  auto p = m(r, 0);
   for (int i = 1; i < m.dimension(1); ++i) {
     if (m(r, i) > p) {
       p = m(r, i);
@@ -170,6 +172,7 @@ class CTCDecodeHelper {
   TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
 };
 
+template <typename T>
 class CTCGreedyDecoderOp : public OpKernel {
  public:
   explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -189,7 +192,7 @@ class CTCGreedyDecoderOp : public OpKernel {
 
     const TensorShape& inputs_shape = inputs->shape();
 
-    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
+    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
     const int64 max_time = inputs_shape.dim_size(0);
     const int64 batch_size = inputs_shape.dim_size(1);
     const int64 num_classes_raw = inputs_shape.dim_size(2);
@@ -198,14 +201,14 @@ class CTCGreedyDecoderOp : public OpKernel {
                 errors::InvalidArgument("num_classes cannot exceed max int"));
     const int num_classes = static_cast<const int>(num_classes_raw);
 
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
 
     for (std::size_t t = 0; t < max_time; ++t) {
       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                 batch_size, num_classes);
     }
     auto seq_len_t = seq_len->vec<int32>();
-    auto log_prob_t = log_prob->matrix<float>();
+    auto log_prob_t = log_prob->matrix<T>();
 
     log_prob_t.setZero();
 
@@ -221,7 +224,8 @@ class CTCGreedyDecoderOp : public OpKernel {
       int prev_indices = -1;
       for (int t = 0; t < seq_len_t(b); ++t) {
         int max_class_indices;
-        log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
+        log_prob_t(b, 0) +=
+            -RowMax<T>(input_list_t[t], b, &max_class_indices);
         if (max_class_indices != blank_index &&
             !(merge_repeated_ && max_class_indices == prev_indices)) {
           sequence.push_back(max_class_indices);
@@ -250,10 +254,18 @@ class CTCGreedyDecoderOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
-                        CTCGreedyDecoderOp);
+#define REGISTER_CPU(T)                                                   \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCGreedyDecoderOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 // CTC beam search
+template <typename T>
 class CTCBeamSearchDecoderOp : public OpKernel {
  public:
   explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -275,9 +287,9 @@ class CTCBeamSearchDecoderOp : public OpKernel {
                     ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                     &decoded_values, &decoded_shape));
 
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
     auto seq_len_t = seq_len->vec<int32>();
-    auto log_prob_t = log_prob->matrix<float>();
+    auto log_prob_t = log_prob->matrix<T>();
 
     const TensorShape& inputs_shape = inputs->shape();
 
@@ -291,21 +303,21 @@ class CTCBeamSearchDecoderOp : public OpKernel {
 
     log_prob_t.setZero();
 
-    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
+    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
 
    for (std::size_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                 batch_size, num_classes);
     }
 
-    ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
-                                            &beam_scorer_, 1 /* batch_size */,
-                                            merge_repeated_);
-    Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
-    auto input_chip_t = input_chip.flat<float>();
+    ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
+                                             &beam_scorer_, 1 /* batch_size */,
+                                             merge_repeated_);
+    Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
+    auto input_chip_t = input_chip.flat<T>();
 
     std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
-    std::vector<float> log_probs;
+    std::vector<T> log_probs;
 
     // Assumption: the blank index is num_classes - 1
     for (int b = 0; b < batch_size; ++b) {
@@ -313,8 +325,8 @@ class CTCBeamSearchDecoderOp : public OpKernel {
       best_paths_b.resize(decode_helper_.GetTopPaths());
       for (int t = 0; t < seq_len_t(b); ++t) {
         input_chip_t = input_list_t[t].chip(b, 0);
-        auto input_bi =
-            Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
+        auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
+            input_chip_t.data(), num_classes);
         beam_search.Step(input_bi);
       }
       OP_REQUIRES_OK(
@@ -335,13 +347,20 @@ class CTCBeamSearchDecoderOp : public OpKernel {
 
  private:
   CTCDecodeHelper decode_helper_;
-  ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
+  typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
   bool merge_repeated_;
   int beam_width_;
-  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
-                        CTCBeamSearchDecoderOp);
+#define REGISTER_CPU(T)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCBeamSearchDecoderOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 }  // end namespace tensorflow
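The beam-search kernel above now allocates its scratch `input_chip` tensor through `DataTypeToEnum<T>::v()` instead of a hard-coded `DT_FLOAT`. A small illustration of the mapping this trait provides (the static_asserts are illustrative, not part of the commit):

```cpp
#include "tensorflow/core/framework/types.h"

// DataTypeToEnum<T> maps a C++ scalar type to the corresponding DataType
// enum at compile time, which is what lets the templated kernel allocate
// a scratch Tensor of the right dtype for any instantiation of T.
static_assert(tensorflow::DataTypeToEnum<float>::value == tensorflow::DT_FLOAT,
              "float maps to DT_FLOAT");
static_assert(
    tensorflow::DataTypeToEnum<double>::value == tensorflow::DT_DOUBLE,
    "double maps to DT_DOUBLE");
```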
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
@@ -26,14 +27,13 @@ limitations under the License.
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
+template <typename T>
 class CTCLossOp : public OpKernel {
-  typedef Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
-                                         Eigen::RowMajor> >
+  typedef Eigen::Map<
+      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
       InputMap;
   typedef Eigen::Map<
-      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
+      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
       OutputMap;
 
  public:
@@ -110,7 +110,7 @@ class CTCLossOp : public OpKernel {
                 errors::InvalidArgument("label SparseTensor is not valid: ",
                                         labels_sp_valid.error_message()));
 
-    ctc::CTCLossCalculator::LabelSequences labels_t(batch_size);
+    typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size);
     for (const auto& g : labels_sp.group({0})) {  // iterate by batch
       const int64 batch_indices = g.group()[0];
       OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
@@ -137,13 +137,13 @@ class CTCLossOp : public OpKernel {
 
     Tensor* loss = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
-    auto loss_t = loss->vec<float>();
+    auto loss_t = loss->vec<T>();
 
     Tensor* gradient;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output("gradient", inputs_shape, &gradient));
-    auto gradient_t = gradient->tensor<float, 3>();
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto gradient_t = gradient->tensor<T, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
     std::vector<OutputMap> gradient_list_t;
     std::vector<InputMap> input_list_t;
 
@@ -158,7 +158,7 @@ class CTCLossOp : public OpKernel {
     gradient_t.setZero();
 
     // Assumption: the blank index is num_classes - 1
-    ctc::CTCLossCalculator ctc_loss_calculator(num_classes - 1, 0);
+    ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0);
     DeviceBase::CpuWorkerThreads workers =
         *ctx->device()->tensorflow_cpu_worker_threads();
     OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
@@ -173,9 +173,17 @@ class CTCLossOp : public OpKernel {
   bool ctc_merge_repeated_;
   bool ignore_longer_outputs_than_inputs_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCLoss").Device(DEVICE_CPU), CTCLossOp);
+#define REGISTER_CPU(T)                                          \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCLossOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 }  // end namespace tensorflow
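For reference, a sketch of what the new registration macro expands to for one of its two instantiations (illustrative expansion only; the macro text is in the hunk above):

```cpp
// Expansion of REGISTER_CPU(double): the TypeConstraint pins this kernel to
// nodes whose "T" attr resolved to DT_DOUBLE, so the float and double kernels
// coexist under the same op name.
REGISTER_KERNEL_BUILDER(
    Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<double>("T"),
    CTCLossOp<double>);
```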
@@ -25,15 +25,16 @@ using shape_inference::ShapeHandle;
 // CTC is Connectionist Temporal Classification.  See util/ctc/ for details.
 
 REGISTER_OP("CTCLoss")
-    .Input("inputs: float")
+    .Input("inputs: T")
     .Input("labels_indices: int64")
     .Input("labels_values: int32")
     .Input("sequence_length: int32")
     .Attr("preprocess_collapse_repeated: bool = false")
     .Attr("ctc_merge_repeated: bool = true")
     .Attr("ignore_longer_outputs_than_inputs: bool = false")
-    .Output("loss: float")
-    .Output("gradient: float")
+    .Output("loss: T")
+    .Output("gradient: T")
+    .Attr("T: {float, double} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle inputs;
       ShapeHandle labels_indices;
@@ -62,13 +63,14 @@ REGISTER_OP("CTCLoss")
     });
 
 REGISTER_OP("CTCGreedyDecoder")
-    .Input("inputs: float")
+    .Input("inputs: T")
     .Input("sequence_length: int32")
     .Attr("merge_repeated: bool = false")
     .Output("decoded_indices: int64")
     .Output("decoded_values: int64")
     .Output("decoded_shape: int64")
-    .Output("log_probability: float")
+    .Output("log_probability: T")
+    .Attr("T: {float, double} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle inputs;
       ShapeHandle sequence_length;
@@ -90,7 +92,7 @@ REGISTER_OP("CTCGreedyDecoder")
     });
 
 REGISTER_OP("CTCBeamSearchDecoder")
-    .Input("inputs: float")
+    .Input("inputs: T")
     .Input("sequence_length: int32")
     .Attr("beam_width: int >= 1")
     .Attr("top_paths: int >= 1")
@@ -98,7 +100,8 @@ REGISTER_OP("CTCBeamSearchDecoder")
     .Output("decoded_indices: top_paths * int64")
     .Output("decoded_values: top_paths * int64")
     .Output("decoded_shape: top_paths * int64")
-    .Output("log_probability: float")
+    .Output("log_probability: T")
+    .Attr("T: {float, double} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle inputs;
       ShapeHandle sequence_length;
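Because each new `T` attr carries the default `= DT_FLOAT`, nodes that predate this change (and therefore carry no `T` attr at all) still resolve to the float kernels when their GraphDef is loaded. A hypothetical sketch of building such a node, assuming the standard NodeDefBuilder API; here `T` is filled in by attr inference from the float input, while the default covers old serialized graphs:

```cpp
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"

// Illustrative only: the node and input names are invented for the example.
void BuildFloatCtcLossNode(tensorflow::NodeDef* node) {
  TF_CHECK_OK(tensorflow::NodeDefBuilder("ctc_loss", "CTCLoss")
                  .Input("inputs", 0, tensorflow::DT_FLOAT)  // infers T=float
                  .Input("labels_indices", 0, tensorflow::DT_INT64)
                  .Input("labels_values", 0, tensorflow::DT_INT32)
                  .Input("sequence_length", 0, tensorflow::DT_INT32)
                  .Finalize(node));
}
```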
@@ -41,30 +41,34 @@ namespace ctc_beam_search {
 
 struct EmptyBeamState {};
 
+template <typename T>
 struct BeamProbability {
-  BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
+  BeamProbability()
+      : total(kLogZero<T>::val),
+        blank(kLogZero<T>::val),
+        label(kLogZero<T>::val) {}
   void Reset() {
-    total = kLogZero;
-    blank = kLogZero;
-    label = kLogZero;
+    total = kLogZero<T>::val;
+    blank = kLogZero<T>::val;
+    label = kLogZero<T>::val;
   }
-  float total;
-  float blank;
-  float label;
+  T total;
+  T blank;
+  T label;
 };
 
-template <class CTCBeamState>
+template <class T, class CTCBeamState>
 class BeamRoot;
 
-template <class CTCBeamState = EmptyBeamState>
+template <class T, class CTCBeamState = EmptyBeamState>
 struct BeamEntry {
   // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
-  friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
-      BeamEntry<CTCBeamState>* p, int l);
-  inline bool Active() const { return newp.total != kLogZero; }
+  friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry(
+      BeamEntry<T, CTCBeamState>* p, int l);
+  inline bool Active() const { return newp.total != kLogZero<T>::val; }
   // Return the child at the given index, or construct a new one in-place if
   // none was found.
-  BeamEntry& GetChild(int ind) {
+  BeamEntry<T, CTCBeamState>& GetChild(int ind) {
     auto entry = children.emplace(ind, nullptr);
     auto& child_entry = entry.first->second;
     // If this is a new child, populate the BeamEntry<CTCBeamState>*.
@@ -76,7 +80,7 @@ struct BeamEntry {
   std::vector<int> LabelSeq(bool merge_repeated) const {
     std::vector<int> labels;
     int prev_label = -1;
-    const BeamEntry* c = this;
+    const BeamEntry<T, CTCBeamState>* c = this;
     while (c->parent != nullptr) {  // Checking c->parent to skip root leaf.
       if (!merge_repeated || c->label != prev_label) {
         labels.push_back(c->label);
@@ -88,12 +92,12 @@ struct BeamEntry {
     return labels;
   }
 
-  BeamEntry<CTCBeamState>* parent;
+  BeamEntry<T, CTCBeamState>* parent;
   int label;
   // All instances of child BeamEntry are owned by *beam_root.
-  gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
-  BeamProbability oldp;
-  BeamProbability newp;
+  gtl::FlatMap<int, BeamEntry<T, CTCBeamState>*> children;
+  BeamProbability<T> oldp;
+  BeamProbability<T> newp;
   CTCBeamState state;
 
  private:
@@ -102,40 +106,42 @@ struct BeamEntry {
   // otherwise parent will become invalid.
   // This private constructor is only called through the factory method
   // BeamRoot<CTCBeamState>::AddEntry().
-  BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
+  BeamEntry(BeamEntry* p, int l, BeamRoot<T, CTCBeamState>* beam_root)
       : parent(p), label(l), beam_root(beam_root) {}
-  BeamRoot<CTCBeamState>* beam_root;
+  BeamRoot<T, CTCBeamState>* beam_root;
   TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
 };
 
 // This class owns all instances of BeamEntry.  This is used to avoid recursive
 // destructor call during destruction.
-template <class CTCBeamState = EmptyBeamState>
+template <class T, class CTCBeamState = EmptyBeamState>
 class BeamRoot {
  public:
-  BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
+  BeamRoot(BeamEntry<T, CTCBeamState>* p, int l) {
+    root_entry_ = AddEntry(p, l);
+  }
   BeamRoot(const BeamRoot&) = delete;
   BeamRoot& operator=(const BeamRoot&) = delete;
 
-  BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
-    auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
+  BeamEntry<T, CTCBeamState>* AddEntry(BeamEntry<T, CTCBeamState>* p, int l) {
+    auto* new_entry = new BeamEntry<T, CTCBeamState>(p, l, this);
     beam_entries_.emplace_back(new_entry);
     return new_entry;
   }
-  BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
+  BeamEntry<T, CTCBeamState>* RootEntry() const { return root_entry_; }
 
  private:
-  BeamEntry<CTCBeamState>* root_entry_ = nullptr;
-  std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
+  BeamEntry<T, CTCBeamState>* root_entry_ = nullptr;
+  std::vector<std::unique_ptr<BeamEntry<T, CTCBeamState>>> beam_entries_;
 };
 
 // BeamComparer is the default beam comparer provided in CTCBeamSearch.
-template <class CTCBeamState = EmptyBeamState>
+template <class T, class CTCBeamState = EmptyBeamState>
 class BeamComparer {
  public:
   virtual ~BeamComparer() {}
-  virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
-                                 const BeamEntry<CTCBeamState>* b) const {
+  virtual bool inline operator()(const BeamEntry<T, CTCBeamState>* a,
+                                 const BeamEntry<T, CTCBeamState>* b) const {
     return a->newp.total > b->newp.total;
   }
 };
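The `BeamProbability<T>` fields above hold log-probabilities: `kLogZero<T>::val` stands in for probability 0 and `T(0)` for probability 1 (ln(1) = 0). Combining two such values uses log-sum-exp; TensorFlow's own `LogSumExp` helper lives in ctc_loss_util.h, which this excerpt does not show, so the following is a hedged sketch of the idea rather than that implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// Illustrative log-space addition: returns log(exp(a) + exp(b)) while
// treating -infinity as an absorbing "probability zero" and keeping the
// exponentials numerically tame by factoring out the larger term.
template <typename T>
T LogSumExpSketch(T a, T b) {
  const T log_zero = -std::numeric_limits<T>::infinity();
  if (a == log_zero) return b;
  if (b == log_zero) return a;
  const T hi = std::max(a, b);
  return hi + std::log1p(std::exp(std::min(a, b) - hi));
}
```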
@@ -34,7 +34,7 @@ namespace ctc {
 // be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex
 // scoring is required.  Its main purpose is to provide a thin layer for
 // integrating language model scoring easily.
-template <typename CTCBeamState>
+template <typename T, typename CTCBeamState>
 class BaseBeamScorer {
  public:
   virtual ~BaseBeamScorer() {}
@@ -56,8 +56,8 @@ class BaseBeamScorer {
   //
   // The score returned should be a log-probability. In the simplest case, as
   // there's no state expansion logic, the expansion score is zero.
-  virtual float GetStateExpansionScore(const CTCBeamState& state,
-                                       float previous_score) const {
+  virtual T GetStateExpansionScore(const CTCBeamState& state,
+                                   T previous_score) const {
     return previous_score;
   }
   // GetStateEndExpansionScore should be an inexpensive method to retrieve the
@@ -65,8 +65,8 @@ class BaseBeamScorer {
   // multiplied (log-addition) with the final probability of the beam.
   //
   // The score returned should be a log-probability.
-  virtual float GetStateEndExpansionScore(const CTCBeamState& state) const {
-    return 0;
+  virtual T GetStateEndExpansionScore(const CTCBeamState& state) const {
+    return T(0);
   }
 };
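With the extra type parameter, custom scorers now subclass `BaseBeamScorer<T, State>`. A minimal sketch against the interface shown above; the class name and the constant-bonus scoring rule are invented for illustration:

```cpp
#include "tensorflow/core/util/ctc/ctc_beam_search.h"

// Illustrative scorer that adds a fixed log-space bonus on every expansion.
// EmptyBeamState and the virtual hooks come from the headers in this diff.
template <typename T>
class ConstantBonusScorer
    : public tensorflow::ctc::BaseBeamScorer<
          T, tensorflow::ctc::ctc_beam_search::EmptyBeamState> {
 public:
  explicit ConstantBonusScorer(T bonus) : bonus_(bonus) {}

  T GetStateExpansionScore(
      const tensorflow::ctc::ctc_beam_search::EmptyBeamState& state,
      T previous_score) const override {
    return previous_score + bonus_;  // still interpreted as a log-probability
  }

 private:
  T bonus_;
};
```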
@@ -38,10 +38,10 @@ limitations under the License.
 namespace tensorflow {
 namespace ctc {
 
-template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
+template <typename T, typename CTCBeamState = ctc_beam_search::EmptyBeamState,
           typename CTCBeamComparer =
-              ctc_beam_search::BeamComparer<CTCBeamState>>
-class CTCBeamSearchDecoder : public CTCDecoder {
+              ctc_beam_search::BeamComparer<T, CTCBeamState>>
+class CTCBeamSearchDecoder : public CTCDecoder<T> {
   // Beam Search
   //
   // Example (GravesTh Fig. 7.5):
@@ -73,12 +73,12 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   //   starts at 0). This special case can be calculated as:
   //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
   //   but we calculate it recursively for speed purposes.
-  typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
-  typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
-  typedef ctc_beam_search::BeamProbability BeamProbability;
+  typedef ctc_beam_search::BeamEntry<T, CTCBeamState> BeamEntry;
+  typedef ctc_beam_search::BeamRoot<T, CTCBeamState> BeamRoot;
+  typedef ctc_beam_search::BeamProbability<T> BeamProbability;
 
  public:
-  typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
+  typedef BaseBeamScorer<T, CTCBeamState> DefaultBeamScorer;
 
   // The beam search decoder is constructed specifying the beam_width (number of
   // candidates to keep at each decoding timestep) and a beam scorer (used for
@@ -87,9 +87,9 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
   // standard beam search.
   CTCBeamSearchDecoder(int num_classes, int beam_width,
-                       BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
-                       bool merge_repeated = false)
-      : CTCDecoder(num_classes, batch_size, merge_repeated),
+                       BaseBeamScorer<T, CTCBeamState>* scorer,
+                       int batch_size = 1, bool merge_repeated = false)
+      : CTCDecoder<T>(num_classes, batch_size, merge_repeated),
         beam_width_(beam_width),
         leaves_(beam_width),
         beam_scorer_(CHECK_NOTNULL(scorer)) {
@@ -99,27 +99,28 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   ~CTCBeamSearchDecoder() override {}
 
   // Run the hibernating beam search algorithm on the given input.
-  Status Decode(const CTCDecoder::SequenceLength& seq_len,
-                const std::vector<CTCDecoder::Input>& input,
-                std::vector<CTCDecoder::Output>* output,
-                CTCDecoder::ScoreOutput* scores) override;
+  Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
+                const std::vector<typename CTCDecoder<T>::Input>& input,
+                std::vector<typename CTCDecoder<T>::Output>* output,
+                typename CTCDecoder<T>::ScoreOutput* scores) override;
 
   // Calculate the next step of the beam search and update the internal state.
   template <typename Vector>
   void Step(const Vector& log_input_t);
 
   template <typename Vector>
-  float GetTopK(const int K, const Vector& input,
-                std::vector<float>* top_k_logits,
-                std::vector<int>* top_k_indices);
+  T GetTopK(const int K, const Vector& input, std::vector<T>* top_k_logits,
+            std::vector<int>* top_k_indices);
 
   // Retrieve the beam scorer instance used during decoding.
-  BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
+  BaseBeamScorer<T, CTCBeamState>* GetBeamScorer() const {
+    return beam_scorer_;
+  }
 
   // Set label selection parameters for faster decoding.
   // See comments for label_selection_size_ and label_selection_margin_.
   void SetLabelSelectionParameters(int label_selection_size,
-                                   float label_selection_margin) {
+                                   T label_selection_margin) {
     label_selection_size_ = label_selection_size;
     label_selection_margin_ = label_selection_margin;
   }
@@ -129,7 +130,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
 
   // Extract the top n paths at current time step
   Status TopPaths(int n, std::vector<std::vector<int>>* paths,
-                  std::vector<float>* log_probs, bool merge_repeated) const;
+                  std::vector<T>* log_probs, bool merge_repeated) const;
 
  private:
   int beam_width_;
@@ -145,37 +146,38 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   // Default is to do no label selection.
   // For more detail: https://research.google.com/pubs/pub44823.html
   int label_selection_size_ = 0;       // zero means unlimited
-  float label_selection_margin_ = -1;  // -1 means unlimited.
+  T label_selection_margin_ = -1;      // -1 means unlimited.
 
   gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
   std::unique_ptr<BeamRoot> beam_root_;
-  BaseBeamScorer<CTCBeamState>* beam_scorer_;
+  BaseBeamScorer<T, CTCBeamState>* beam_scorer_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
 };
 
-template <typename CTCBeamState, typename CTCBeamComparer>
-Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
-    const CTCDecoder::SequenceLength& seq_len,
-    const std::vector<CTCDecoder::Input>& input,
-    std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
+template <typename T, typename CTCBeamState, typename CTCBeamComparer>
+Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Decode(
+    const typename CTCDecoder<T>::SequenceLength& seq_len,
+    const std::vector<typename CTCDecoder<T>::Input>& input,
+    std::vector<typename CTCDecoder<T>::Output>* output,
+    typename CTCDecoder<T>::ScoreOutput* scores) {
   // Storage for top paths.
   std::vector<std::vector<int>> beams;
-  std::vector<float> beam_log_probabilities;
+  std::vector<T> beam_log_probabilities;
   int top_n = output->size();
   if (std::any_of(output->begin(), output->end(),
-                  [this](const CTCDecoder::Output& output) -> bool {
+                  [this](const typename CTCDecoder<T>::Output& output) -> bool {
                     return output.size() < this->batch_size_;
                   })) {
     return errors::InvalidArgument(
        "output needs to be of size at least (top_n, batch_size).");
   }
-  if (scores->rows() < batch_size_ || scores->cols() < top_n) {
+  if (scores->rows() < this->batch_size_ || scores->cols() < top_n) {
     return errors::InvalidArgument(
         "scores needs to be of size at least (batch_size, top_n).");
  }
 
-  for (int b = 0; b < batch_size_; ++b) {
+  for (int b = 0; b < this->batch_size_; ++b) {
     int seq_len_b = seq_len[b];
     Reset();
 
@@ -196,7 +198,7 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
   }
 
   Status status =
-      TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
+      TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_);
   if (!status.ok()) {
     return status;
   }
@@ -213,20 +215,20 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
   return Status::OK();
 }
 
-template <typename CTCBeamState, typename CTCBeamComparer>
+template <typename T, typename CTCBeamState, typename CTCBeamComparer>
 template <typename Vector>
-float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
-    const int K, const Vector& input, std::vector<float>* top_k_logits,
+T CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::GetTopK(
+    const int K, const Vector& input, std::vector<T>* top_k_logits,
     std::vector<int>* top_k_indices) {
   // Find Top K choices, complexity nk in worst case. The array input is read
   // just once.
-  CHECK_EQ(num_classes_, input.size());
+  CHECK_EQ(this->num_classes_, input.size());
   top_k_logits->clear();
   top_k_indices->clear();
   top_k_logits->resize(K, -INFINITY);
   top_k_indices->resize(K, -1);
-  for (int j = 0; j < num_classes_ - 1; ++j) {
-    const float logit = input(j);
+  for (int j = 0; j < this->num_classes_ - 1; ++j) {
+    const T logit = input(j);
     if (logit > (*top_k_logits)[K - 1]) {
       int k = K - 1;
       while (k > 0 && logit > (*top_k_logits)[k - 1]) {
@@ -239,43 +241,43 @@ float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
     }
   }
   // Return max value which is in 0th index or blank character logit
-  return std::max((*top_k_logits)[0], input(num_classes_ - 1));
+  return std::max((*top_k_logits)[0], input(this->num_classes_ - 1));
 }
 
-template <typename CTCBeamState, typename CTCBeamComparer>
+template <typename T, typename CTCBeamState, typename CTCBeamComparer>
 template <typename Vector>
-void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
+void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
     const Vector& raw_input) {
-  std::vector<float> top_k_logits;
+  std::vector<T> top_k_logits;
   std::vector<int> top_k_indices;
   const bool top_k =
       (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
   // Number of character classes to consider in each step.
-  const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
+  const int max_classes =
+      top_k ? label_selection_size_ : (this->num_classes_ - 1);
   // Get max coefficient and remove it from raw_input later.
-  float max_coeff;
+  T max_coeff;
  if (top_k) {
     max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
                         &top_k_indices);
   } else {
     max_coeff = raw_input.maxCoeff();
   }
 
   // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
-  float logsumexp = 0.0;
+  T logsumexp = T(0.0);
   for (int j = 0; j < raw_input.size(); ++j) {
     logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
   }
   logsumexp = Eigen::numext::log(logsumexp);
   // Final normalization offset to get correct log probabilities.
-  float norm_offset = max_coeff + logsumexp;
+  T norm_offset = max_coeff + logsumexp;
 
-  const float label_selection_input_min =
+  const T label_selection_input_min =
       (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
-                                     : -std::numeric_limits<float>::infinity();
+                                     : -std::numeric_limits<T>::infinity();
 
   // Extract the beams sorted in decreasing new probability
-  CHECK_EQ(num_classes_, raw_input.size());
+  CHECK_EQ(this->num_classes_, raw_input.size());
 
   std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
   leaves_.Reset();
@@ -294,8 +296,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
       //   else:
       //     Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
       //                           + P(l=ab @ t=5))
-      float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
-                                                      : b->parent->oldp.total;
+      T previous = (b->label == b->parent->label) ? b->parent->oldp.blank
                                                 : b->parent->oldp.total;
      b->newp.label =
          LogSumExp(b->newp.label,
                    beam_scorer_->GetStateExpansionScore(b->state, previous));
@@ -304,7 +306,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
       b->newp.label += raw_input(b->label) - norm_offset;
     }
     // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
-    b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
+    b->newp.blank = b->oldp.total + raw_input(this->blank_index_) - norm_offset;
     // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
     b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
 
@@ -325,7 +327,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
     // isn't full, or the lowest probability entry in the beam has a
     // lower probability than the leaf.
     auto is_candidate = [this](const BeamProbability& prob) {
-      return (prob.total > kLogZero &&
+      return (prob.total > kLogZero<T>::val &&
              (leaves_.size() < beam_width_ ||
               prob.total > leaves_.peek_bottom()->newp.total));
    };
@@ -336,7 +338,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
 
     for (int ind = 0; ind < max_classes; ind++) {
       const int label = top_k ? top_k_indices[ind] : ind;
-      const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
+      const T logit = top_k ? top_k_logits[ind] : raw_input(ind);
       // Perform label selection: if input for this label looks very
       // unpromising, never evaluate it with a scorer.
       // We may compare logits instead of log probabilities,
@@ -347,13 +349,13 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
       BeamEntry& c = b->GetChild(label);
       if (!c.Active()) {
         //   Pblank(l=abcd @ t=6) = 0
-        c.newp.blank = kLogZero;
+        c.newp.blank = kLogZero<T>::val;
         // If new child label is identical to beam label:
         //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
         // Otherwise:
         //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
         beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
-        float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
+        T previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
         c.newp.label = logit - norm_offset +
                        beam_scorer_->GetStateExpansionScore(c.state, previous);
         // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
@@ -379,15 +381,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
   }  // for (BeamEntry* b...
 }
 
-template <typename CTCBeamState, typename CTCBeamComparer>
-void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
+template <typename T, typename CTCBeamState, typename CTCBeamComparer>
+void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Reset() {
   leaves_.Reset();
 
   // This beam root, and all of its children, will be in memory until
   // the next reset.
   beam_root_.reset(new BeamRoot(nullptr, -1));
-  beam_root_->RootEntry()->newp.total = 0.0;   // ln(1)
-  beam_root_->RootEntry()->newp.blank = 0.0;   // ln(1)
+  beam_root_->RootEntry()->newp.total = T(0.0);   // ln(1)
+  beam_root_->RootEntry()->newp.blank = T(0.0);   // ln(1)
 
   // Add the root as the initial leaf.
   leaves_.push(beam_root_->RootEntry());
@@ -396,9 +398,9 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
   beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
 }
 
-template <typename CTCBeamState, typename CTCBeamComparer>
-Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
-    int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
+template <typename T, typename CTCBeamState, typename CTCBeamComparer>
+Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::TopPaths(
+    int n, std::vector<std::vector<int>>* paths, std::vector<T>* log_probs,
     bool merge_repeated) const {
   CHECK_NOTNULL(paths)->clear();
   CHECK_NOTNULL(log_probs)->clear();
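Taken together, the templated decoder is driven exactly as the kernel drives it: construct with a scorer, feed one timestep of raw logits per `Step`, then extract `TopPaths`. A minimal double-precision sketch of that loop; the function name and sizes are invented, and it assumes the constructor leaves the decoder in a reset state as the kernel's per-batch `Reset()` call suggests:

```cpp
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"

// Illustrative driver for one sequence (batch_size == 1), mirroring what
// CTCBeamSearchDecoderOp<T> does per batch entry in the kernel diff above.
void DecodeOneSequence(const std::vector<Eigen::ArrayXd>& logits_per_step,
                       int num_classes, int beam_width, int top_paths) {
  tensorflow::ctc::CTCBeamSearchDecoder<double>::DefaultBeamScorer scorer;
  tensorflow::ctc::CTCBeamSearchDecoder<double> decoder(
      num_classes, beam_width, &scorer, 1 /* batch_size */,
      false /* merge_repeated */);

  decoder.Reset();  // start from a fresh root beam
  for (const Eigen::ArrayXd& input_t : logits_per_step) {
    decoder.Step(input_t);  // consumes one timestep of raw (unnormalized) logits
  }

  std::vector<std::vector<int>> paths;
  std::vector<double> log_probs;
  TF_CHECK_OK(decoder.TopPaths(top_paths, &paths, &log_probs,
                               false /* merge_repeated */));
}
```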
@@ -19,19 +19,20 @@ limitations under the License.
 #include "tensorflow/core/util/ctc/ctc_beam_search.h"
 
+#include <cmath>
 
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace {
 
-typedef std::vector<std::vector<std::vector<float>>> TestData;
-using tensorflow::ctc::CTCBeamSearchDecoder;
-using tensorflow::ctc::CTCDecoder;
+template <class T>
+using TestData = std::vector<std::vector<std::vector<T>>>;
 
 // The HistoryBeamState is used to keep track of the current candidate and
 // caches the expansion score (needed by the scorer).
+template <class T>
 struct HistoryBeamState {
-  float score;
+  T score;
   std::vector<int> labels;
 };
@@ -40,48 +41,48 @@ struct HistoryBeamState {
 // a prefix of a dictionary word it gets a low probability at each step.
 //
 // The dictionary itself is hard-coded a static const variable of the class.
+template <class T, class BeamState>
 class DictionaryBeamScorer
-    : public tensorflow::ctc::BaseBeamScorer<HistoryBeamState> {
+    : public tensorflow::ctc::BaseBeamScorer<T, BeamState> {
  public:
-  void InitializeState(HistoryBeamState* root) const override {
-    root->score = 0;
-  }
+  DictionaryBeamScorer()
+      : tensorflow::ctc::BaseBeamScorer<T, BeamState>(),
+        dictionary_({{3}, {3, 1}}) {}
 
-  void ExpandState(const HistoryBeamState& from_state, int from_label,
-                   HistoryBeamState* to_state, int to_label) const override {
+  void InitializeState(BeamState* root) const override { root->score = 0; }
+
+  void ExpandState(const BeamState& from_state, int from_label,
+                   BeamState* to_state, int to_label) const override {
     // Keep track of the current complete candidate by storing the labels along
     // the expansion path in the beam state.
     to_state->labels.push_back(to_label);
     SetStateScoreAccordingToDict(to_state);
   }
 
-  void ExpandStateEnd(HistoryBeamState* state) const override {
+  void ExpandStateEnd(BeamState* state) const override {
     SetStateScoreAccordingToDict(state);
   }
 
-  float GetStateExpansionScore(const HistoryBeamState& state,
-                               float previous_score) const override {
+  T GetStateExpansionScore(const BeamState& state,
+                           T previous_score) const override {
     return previous_score + state.score;
   }
 
-  float GetStateEndExpansionScore(
-      const HistoryBeamState& state) const override {
+  T GetStateEndExpansionScore(const BeamState& state) const override {
    return state.score;
  }
 
   // Simple dictionary used when scoring the beams to check if they are prefixes
   // of dictionary words (see SetStateScoreAccordingToDict below).
-  static const std::vector<std::vector<int>> dictionary_;
+  const std::vector<std::vector<int>> dictionary_;
 
  private:
-  void SetStateScoreAccordingToDict(HistoryBeamState* state) const;
+  void SetStateScoreAccordingToDict(BeamState* state) const;
 };
 
-const std::vector<std::vector<int>> DictionaryBeamScorer::dictionary_ = {
-    {3}, {3, 1}};
-
-void DictionaryBeamScorer::SetStateScoreAccordingToDict(
-    HistoryBeamState* state) const {
+template <class T, class BeamState>
+void DictionaryBeamScorer<T, BeamState>::SetStateScoreAccordingToDict(
+    BeamState* state) const {
   // Check if the beam can still be a dictionary word (e.g. prefix of one).
   const std::vector<int>& candidate = state->labels;
   for (int w = 0; w < dictionary_.size(); ++w) {
@@ -92,32 +93,35 @@ void DictionaryBeamScorer::SetStateScoreAccordingToDict(
     }
     if (std::equal(word.begin(), word.begin() + candidate.size(),
                    candidate.begin())) {
-      state->score = std::log(1.0);
+      state->score = std::log(T(1.0));
       return;
     }
   }
   // At this point, the candidate certainly can't be in the dictionary.
-  state->score = std::log(0.01);
+  state->score = std::log(T(0.01));
 }
 
-TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
+template <class T>
+void ctc_beam_search_decoding_with_and_without_dictionary() {
   const int batch_size = 1;
   const int timesteps = 5;
   const int top_paths = 3;
   const int num_classes = 6;
 
   // Plain decoder using hibernating beam search algorithm.
-  CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
-  CTCBeamSearchDecoder<> decoder(num_classes, 10 * top_paths, &default_scorer);
+  typename tensorflow::ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer
+      default_scorer;
+  tensorflow::ctc::CTCBeamSearchDecoder<T> decoder(num_classes, 10 * top_paths,
+                                                   &default_scorer);
 
   // Dictionary decoder, allowing only two dictionary words : {3}, {3, 1}.
-  DictionaryBeamScorer dictionary_scorer;
-  CTCBeamSearchDecoder<HistoryBeamState> dictionary_decoder(
-      num_classes, top_paths, &dictionary_scorer);
+  DictionaryBeamScorer<T, HistoryBeamState<T>> dictionary_scorer;
+  tensorflow::ctc::CTCBeamSearchDecoder<T, HistoryBeamState<T>>
+      dictionary_decoder(num_classes, top_paths, &dictionary_scorer);
 
-  // Raw data containers (arrays of floats, ints, etc.).
+  // Raw data containers (arrays of floats64, ints, etc.).
   int sequence_lengths[batch_size] = {timesteps};
-  float input_data_mat[timesteps][batch_size][num_classes] = {
+  T input_data_mat[timesteps][batch_size][num_classes] = {
       {{0, 0.6, 0, 0.4, 0, 0}},
       {{0, 0.5, 0, 0.5, 0, 0}},
      {{0, 0.4, 0, 0.6, 0, 0}},
@@ -134,34 +138,40 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
   }
 
   // Plain output, without any additional scoring.
-  std::vector<CTCDecoder::Output> expected_output = {
-      {{1, 3}, {1, 3, 1}, {3, 1, 3}},
-  };
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> expected_output =
+      {
+          {{1, 3}, {1, 3, 1}, {3, 1, 3}},
+      };
 
   // Dictionary outputs: preference for dictionary candidates.  The
   // second-candidate is there, despite it not being a dictionary word, due to
   // stronger probability in the input to the decoder.
-  std::vector<CTCDecoder::Output> expected_dict_output = {
-      {{3}, {1, 3}, {3, 1}},
-  };
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
+      expected_dict_output = {
+          {{3}, {1, 3}, {3, 1}},
+      };
 
   // Convert data containers to the format accepted by the decoder, simply
   // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
   // using Eigen::Map.
   Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
-  std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
+  std::vector<
+      Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>>
+      inputs;
   inputs.reserve(timesteps);
   for (int t = 0; t < timesteps; ++t) {
     inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
   }
 
   // Prepare containers for output and scores.
-  std::vector<CTCDecoder::Output> outputs(top_paths);
-  for (CTCDecoder::Output& output : outputs) {
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> outputs(
+      top_paths);
+  for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : outputs) {
     output.resize(batch_size);
   }
-  float score[batch_size][top_paths] = {{0.0}};
-  Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
+  T score[batch_size][top_paths] = {{0.0}};
+  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> scores(
+      &score[0][0], batch_size, top_paths);
 
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
@@ -169,8 +179,9 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
   }
 
   // Prepare dictionary outputs.
-  std::vector<CTCDecoder::Output> dict_outputs(top_paths);
-  for (CTCDecoder::Output& output : dict_outputs) {
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> dict_outputs(
+      top_paths);
+  for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : dict_outputs) {
     output.resize(batch_size);
   }
   EXPECT_TRUE(
@@ -180,38 +191,45 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
   }
 }
 
-TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
+template <class T>
+void ctc_beam_search_decoding_all_beam_elements_have_finite_scores() {
   const int batch_size = 1;
   const int timesteps = 1;
   const int top_paths = 3;
   const int num_classes = 6;
 
   // Plain decoder using hibernating beam search algorithm.
-  CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
-  CTCBeamSearchDecoder<> decoder(num_classes, top_paths, &default_scorer);
+  typename tensorflow::ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer
+      default_scorer;
+  tensorflow::ctc::CTCBeamSearchDecoder<T> decoder(num_classes, top_paths,
+                                                   &default_scorer);
 
-  // Raw data containers (arrays of floats, ints, etc.).
+  // Raw data containers (arrays of floats64, ints, etc.).
   int sequence_lengths[batch_size] = {timesteps};
-  float input_data_mat[timesteps][batch_size][num_classes] = {
+  T input_data_mat[timesteps][batch_size][num_classes] = {
      {{0.4, 0.3, 0, 0, 0, 0.5}}};
 
   // Convert data containers to the format accepted by the decoder, simply
   // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
   // using Eigen::Map.
   Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
-  std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
+  std::vector<
+      Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>>
+      inputs;
   inputs.reserve(timesteps);
   for (int t = 0; t < timesteps; ++t) {
     inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
   }
 
   // Prepare containers for output and scores.
-  std::vector<CTCDecoder::Output> outputs(top_paths);
-  for (CTCDecoder::Output& output : outputs) {
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> outputs(
+      top_paths);
+  for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : outputs) {
    output.resize(batch_size);
   }
-  float score[batch_size][top_paths] = {{0.0}};
-  Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
+  T score[batch_size][top_paths] = {{0.0}};
+  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> scores(
+      &score[0][0], batch_size, top_paths);
 
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   // Make sure all scores are finite.
@@ -226,8 +244,9 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
 
 typedef int LabelState;  // The state is simply the final label.
 
+template <class T>
 class RapidlyDroppingLabelScorer
-    : public tensorflow::ctc::BaseBeamScorer<LabelState> {
+    : public tensorflow::ctc::BaseBeamScorer<T, LabelState> {
  public:
   void InitializeState(LabelState* root) const override {}
 
@@ -238,75 +257,84 @@ class RapidlyDroppingLabelScorer
 
   void ExpandStateEnd(LabelState* state) const override {}
 
-  float GetStateExpansionScore(const LabelState& state,
-                               float previous_score) const override {
+  T GetStateExpansionScore(const LabelState& state,
+                           T previous_score) const override {
     // Drop off rapidly for later labels.
-    const float kRapidly = 100;
+    const T kRapidly = 100;
     return previous_score - kRapidly * state;
   }
 
-  float GetStateEndExpansionScore(const LabelState& state) const override {
-    return 0;
+  T GetStateEndExpansionScore(const LabelState& state) const override {
+    return T(0);
   }
 };
 
-TEST(CtcBeamSearch, LabelSelection) {
+template <class T>
+void ctc_beam_search_label_selection() {
   const int batch_size = 1;
   const int timesteps = 3;
   const int top_paths = 5;
   const int num_classes = 6;
 
   // Decoder which drops off log-probabilities for labels 0 >> 1 >> 2 >> 3.
-  RapidlyDroppingLabelScorer scorer;
-  CTCBeamSearchDecoder<LabelState> decoder(num_classes, top_paths, &scorer);
+  RapidlyDroppingLabelScorer<T> scorer;
+  tensorflow::ctc::CTCBeamSearchDecoder<T, LabelState> decoder(
+      num_classes, top_paths, &scorer);
 
-  // Raw data containers (arrays of floats, ints, etc.).
+  // Raw data containers (arrays of floats64, ints, etc.).
   int sequence_lengths[batch_size] = {timesteps};
   // Log probabilities, slightly preferring later labels, this decision
   // should be overridden by the scorer which strongly prefers earlier labels.
   // The last one is empty label, and for simplicity we give it an extremely
   // high cost to ignore it. We also use the first label to break up the
   // repeated label sequence.
-  float input_data_mat[timesteps][batch_size][num_classes] = {
+  T input_data_mat[timesteps][batch_size][num_classes] = {
      {{-1e6, 1, 2, 3, 4, -1e6}},
      {{1e6, 0, 0, 0, 0, -1e6}},  // force label 0 to break up repeated
      {{-1e6, 1.1, 2.2, 3.3, 4.4, -1e6}},
  };
 
   // Expected output without label selection
-  std::vector<CTCDecoder::Output> expected_default_output = {
-      {{1, 0, 1}, {1, 0, 2}, {2, 0, 1}, {1, 0, 3}, {2, 0, 2}},
-  };
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
+      expected_default_output = {
+          {{1, 0, 1}, {1, 0, 2}, {2, 0, 1}, {1, 0, 3}, {2, 0, 2}},
+      };
 
   // Expected output with label selection limiting to 2 items
   // this is suboptimal because only labels 3 and 4 were allowed to be seen.
-  std::vector<CTCDecoder::Output> expected_output_size2 = {
-      {{3, 0, 3}, {3, 0, 4}, {4, 0, 3}, {4, 0, 4}, {3}},
-  };
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
+      expected_output_size2 = {
+          {{3, 0, 3}, {3, 0, 4}, {4, 0, 3}, {4, 0, 4}, {3}},
+      };
 
   // Expected output with label width of 2.0. This would permit three labels at
   // the first timestep, but only two at the last.
-  std::vector<CTCDecoder::Output> expected_output_width2 = {
-      {{2, 0, 3}, {2, 0, 4}, {3, 0, 3}, {3, 0, 4}, {4, 0, 3}},
-  };
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
+      expected_output_width2 = {
+          {{2, 0, 3}, {2, 0, 4}, {3, 0, 3}, {3, 0, 4}, {4, 0, 3}},
+      };
 
   // Convert data containers to the format accepted by the decoder, simply
   // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
   // using Eigen::Map.
   Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
-  std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
+  std::vector<
      Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>>
      inputs;
   inputs.reserve(timesteps);
   for (int t = 0; t < timesteps; ++t) {
     inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
   }
 
   // Prepare containers for output and scores.
-  std::vector<CTCDecoder::Output> outputs(top_paths);
-  for (CTCDecoder::Output& output : outputs) {
+  std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> outputs(
+      top_paths);
+  for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : outputs) {
     output.resize(batch_size);
   }
-  float score[batch_size][top_paths] = {{0.0}};
-  Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
+  T score[batch_size][top_paths] = {{0.0}};
+  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> scores(
      &score[0][0], batch_size, top_paths);
 
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
@@ -314,14 +342,14 @@ TEST(CtcBeamSearch, LabelSelection) {
   }
 
   // Try label selection size 2
-  decoder.SetLabelSelectionParameters(2, -1);
+  decoder.SetLabelSelectionParameters(2, T(-1));
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
     EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
   }
 
   // Try label selection width 2.0
-  decoder.SetLabelSelectionParameters(0, 2.0);
+  decoder.SetLabelSelectionParameters(0, T(2.0));
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
     EXPECT_EQ(outputs[path][0], expected_output_width2[0][path]);
@@ -329,18 +357,42 @@ TEST(CtcBeamSearch, LabelSelection) {
 
   // Try both size 2 and width 2.0: the former is more constraining, so
   // it's equivalent to that.
-  decoder.SetLabelSelectionParameters(2, 2.0);
+  decoder.SetLabelSelectionParameters(2, T(2.0));
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
     EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
   }
 
   // Size 4 and width > 3.3 are equivalent to no label selection
-  decoder.SetLabelSelectionParameters(4, 3.3001);
+  decoder.SetLabelSelectionParameters(4, T(3.3001));
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
     EXPECT_EQ(outputs[path][0], expected_default_output[0][path]);
   }
 }
 
+TEST(CtcBeamSearch, FloatDecodingWithAndWithoutDictionary) {
+  ctc_beam_search_decoding_with_and_without_dictionary<float>();
+}
+
+TEST(CtcBeamSearch, DoubleDecodingWithAndWithoutDictionary) {
+  ctc_beam_search_decoding_with_and_without_dictionary<double>();
+}
+
+TEST(CtcBeamSearch, FloatAllBeamElementsHaveFiniteScores) {
+  ctc_beam_search_decoding_all_beam_elements_have_finite_scores<float>();
+}
+
+TEST(CtcBeamSearch, DoubleAllBeamElementsHaveFiniteScores) {
+  ctc_beam_search_decoding_all_beam_elements_have_finite_scores<double>();
+}
+
+TEST(CtcBeamSearch, FloatLabelSelection) {
+  ctc_beam_search_label_selection<float>();
+}
+
+TEST(CtcBeamSearch, DoubleLabelSelection) {
+  ctc_beam_search_label_selection<double>();
+}
+
 }  // namespace
@ -33,12 +33,15 @@ namespace ctc {
|
||||
// The two types of decoding available are:
|
||||
// - greedy path, through the CTCGreedyDecoder
|
||||
// - beam search, through the CTCBeamSearchDecoder
|
||||
template <class T>
|
||||
class CTCDecoder {
|
||||
public:
|
||||
typedef Eigen::Map<const Eigen::ArrayXi> SequenceLength;
|
||||
typedef Eigen::Map<const Eigen::MatrixXf> Input;
|
||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
Input;
|
||||
typedef std::vector<std::vector<int>> Output;
|
||||
typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;
|
||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||
ScoreOutput;
|
||||
|
||||
CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
|
||||
: num_classes_(num_classes),
|
||||
@ -69,25 +72,27 @@ class CTCDecoder {

// CTCGreedyDecoder is an implementation of the simple best path decoding
// algorithm, selecting the most likely class at each timestep.
class CTCGreedyDecoder : public CTCDecoder {
template <class T>
class CTCGreedyDecoder : public CTCDecoder<T> {
 public:
  typedef CTCDecoder<T> Decoder;
  CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
      : CTCDecoder(num_classes, batch_size, merge_repeated) {}
      : CTCDecoder<T>(num_classes, batch_size, merge_repeated) {}

  Status Decode(const CTCDecoder::SequenceLength& seq_len,
                const std::vector<CTCDecoder::Input>& input,
                std::vector<CTCDecoder::Output>* output,
                CTCDecoder::ScoreOutput* scores) override {
    if (output->empty() || (*output)[0].size() < batch_size_) {
  Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
                const std::vector<typename CTCDecoder<T>::Input>& input,
                std::vector<typename CTCDecoder<T>::Output>* output,
                typename CTCDecoder<T>::ScoreOutput* scores) override {
    if (output->empty() || (*output)[0].size() < Decoder::batch_size_) {
      return errors::InvalidArgument(
          "output needs to be of size at least (1, batch_size).");
    }
    if (scores->rows() < batch_size_ || scores->cols() == 0) {
    if (scores->rows() < Decoder::batch_size_ || scores->cols() == 0) {
      return errors::InvalidArgument(
          "scores needs to be of size at least (batch_size, 1).");
    }
    // For each batch entry, identify the transitions
    for (int b = 0; b < batch_size_; ++b) {
    for (int b = 0; b < Decoder::batch_size_; ++b) {
      int seq_len_b = seq_len[b];
      // Only writing to beam 0
      std::vector<int>& output_b = (*output)[0][b];
@ -98,8 +103,8 @@ class CTCGreedyDecoder : public CTCDecoder {
        auto row = input[t].row(b);
        int max_class_ix;
        (*scores)(b, 0) += -row.maxCoeff(&max_class_ix);
        if (max_class_ix != blank_index_ &&
            !(merge_repeated_ && max_class_ix == prev_class_ix)) {
        if (max_class_ix != Decoder::blank_index_ &&
            !(Decoder::merge_repeated_ && max_class_ix == prev_class_ix)) {
          output_b.push_back(max_class_ix);
        }
        prev_class_ix = max_class_ix;
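The switch from bare batch_size_ to Decoder::batch_size_ is not cosmetic: once CTCGreedyDecoder is a template, its base CTCDecoder<T> is a dependent type, and C++ two-phase lookup does not search dependent bases for unqualified names. A standalone sketch of the rule (hypothetical types, not TensorFlow code):

template <class T>
struct Base {
  int batch_size_ = 0;
};

template <class T>
struct Derived : Base<T> {
  typedef Base<T> B;
  int Size() {
    // return batch_size_;  // error: not found at template definition time
    return B::batch_size_;  // OK; this->batch_size_ would also work
  }
};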
@ -14,173 +14,9 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"

#include <cmath>

namespace tensorflow {
namespace ctc {

// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
void CTCLossCalculator::CalculateForwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_alpha) const {
  // Number of cols is the number of time steps = number of cols in target
  // after the output delay.
  log_alpha->setConstant(kLogZero);

  int U = l_prime.size();
  int T = log_alpha->cols();

  CHECK_EQ(U, log_alpha->rows());

  // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
  log_alpha->coeffRef(0, 0) = std::log(y(blank_index_, output_delay_));
  // Below, l_prime[1] == labels[0]
  auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
  log_alpha->coeffRef(1, 0) = std::log(y(label_0, output_delay_));

  for (int t = 1; t < T; ++t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_alpha(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.9
      // Add in the u, t - 1 term.
      float sum_log_alpha = kLogZero;
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        sum_log_alpha = log_alpha->coeff(u, t - 1);
      }

      // Add in the u - 1, t - 1 term.
      if (u > 0) {
        sum_log_alpha =
            LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
      }

      // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
      if (u > 1) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha =
              LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
        }
      }
      // Multiply the summed alphas with the activation log probability.
      log_alpha->coeffRef(u, t) =
          std::log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
    }  // End (GravesTh) Eq 7.9.
  }
}
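For reference, the loop above evaluates (GravesTh) Eq 7.9 in log space. With y_k^t the softmax output for class k at time t and l' the blank-augmented label sequence, the recursion it implements is

\alpha(u, t) \;=\; y_{l'_u}^{t}\,\Big(\alpha(u, t-1) + \alpha(u-1, t-1) + \alpha(u-2, t-1)\Big),

where the \alpha(u-2, t-1) term is included only when l'_u \neq \text{blank} and l'_u \neq l'_{u-2}; the code stores \log\alpha and replaces the inner sum with LogSumExp.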
// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
void CTCLossCalculator::CalculateBackwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_beta) const {
  // Number of cols is the number of time steps = number of cols in target.
  // Matrix log_beta =
  //     Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
  //     kLogZero);
  log_beta->setConstant(kLogZero);
  int T = log_beta->cols();
  int U = l_prime.size();
  CHECK_EQ(U, log_beta->rows());

  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;

  for (int t = T - 1 - 1; t >= 0; --t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_beta(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.15
      // Add in the u, t + 1 term.
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u, t + 1) +
                          std::log(y(l_prime[u], output_delay_ + t + 1)));
      }

      // Add in the u + 1, t + 1 term.
      if (u + 1 < U) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u + 1, t + 1) +
                          std::log(y(l_prime[u + 1], output_delay_ + t + 1)));
      }

      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
      if (u + 2 < U) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          // Add in u + 2 term.
          log_beta->coeffRef(u, t) =
              LogSumExp(log_beta->coeff(u, t),
                        log_beta->coeff(u + 2, t + 1) +
                            std::log(y(l_prime[u + 2], output_delay_ + t + 1)));
        }
      }  // End (GravesTh) Eq. 7.15
    }
  }
}
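The backward pass mirrors the forward one; in probability space the recursion (GravesTh Eq 7.15) is

\beta(u, t) \;=\; \beta(u, t+1)\,y_{l'_u}^{t+1} + \beta(u+1, t+1)\,y_{l'_{u+1}}^{t+1} + \beta(u+2, t+1)\,y_{l'_{u+2}}^{t+1},

with the last term dropped when l'_u is blank or equals l'_{u+2}, again evaluated in log space via LogSumExp.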
// Using (GravesTh) Eq 7.26 & 7.34.
void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
                                          const Matrix& y,
                                          const Matrix& log_alpha,
                                          const Matrix& log_beta,
                                          float log_p_z_x, Matrix* dy) const {
  // Only working with the leftmost part of dy for this batch element.
  auto dy_b = dy->leftCols(y.cols());

  // It is possible that no valid path is found if the activations for the
  // targets are zero.
  if (log_p_z_x == kLogZero) {
    LOG(WARNING) << "No valid path found.";
    dy_b = y;
    return;
  }

  int L = y.rows();
  int T = y.cols();
  int U = l_prime.size();

  for (int t = 0; t < T - output_delay_; ++t) {
    Array prob_sum(L);
    prob_sum.setConstant(kLogZero);

    for (int u = 0; u < U; ++u) {
      int l = l_prime[u];
      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
    }

    for (int l = 0; l < L; ++l) {
      // Negative term in (GravesTh) Eq 7.28.
      float negative_term = expf(prob_sum[l] - log_p_z_x);

      dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
    }
  }
}
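The gradient loop realizes (GravesTh) Eq 7.28: with p(z|x) the total path probability, the derivative with respect to the unnormalized activation a_k^t is

\frac{\partial \mathcal{L}}{\partial a_k^t} \;=\; y_k^t \;-\; \frac{1}{p(z \mid x)} \sum_{u \,:\, l'_u = k} \alpha(u, t)\,\beta(u, t),

which is exactly y minus exp(prob_sum - log_p_z_x) in the code above.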
void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
                                         std::vector<int>* l_prime) const {
  // Assumption is that l_prime is empty.
  l_prime->reserve(2 * l.size() + 1);

  for (auto label : l) {
    l_prime->push_back(blank_index_);
    l_prime->push_back(label);
  }
  // Add final blank to l'.
  l_prime->push_back(blank_index_);
}
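A usage sketch of the blank interleaving, with hypothetical label values and the non-templated class above:

CTCLossCalculator calc(/*blank_index=*/0, /*output_delay=*/0);
std::vector<int> labels = {1, 2, 2};
std::vector<int> l_prime;
calc.GetLPrimeIndices(labels, &l_prime);
// l_prime is now {0, 1, 0, 2, 0, 2, 0}: 2 * labels.size() + 1 == 7 entries.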
}  // namespace ctc
namespace ctc {}  // namespace ctc
}  // namespace tensorflow

@ -30,6 +30,7 @@ limitations under the License.
namespace tensorflow {
namespace ctc {

template <class T>
class CTCLossCalculator {
  // Connectionist Temporal Classification Loss
  //
@ -50,10 +51,14 @@ class CTCLossCalculator {
  // Neural Networks" (PhD Thesis), Technische Universität München.
 public:
  typedef std::vector<std::vector<int>> LabelSequences;
  typedef Eigen::MatrixXf Matrix;
  typedef Eigen::ArrayXf Array;
  typedef Eigen::Map<const Eigen::MatrixXf> InputMap;
  typedef Eigen::Map<Eigen::MatrixXf> OutputMap;
  using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
  // typedef Eigen::MatrixXd Matrix;
  using Array = Eigen::Array<T, Eigen::Dynamic, 1>;
  // typedef Eigen::ArrayXd Array;
  using InputMap = Eigen::Map<const Matrix>;
  // typedef Eigen::Map<const Eigen::MatrixXd> InputMap;
  using OutputMap = Eigen::Map<Matrix>;
  // typedef Eigen::Map<Eigen::MatrixXd> OutputMap;

  CTCLossCalculator(int blank_index, int output_delay)
      : blank_index_(blank_index), output_delay_(output_delay) {}
@ -79,7 +84,7 @@ class CTCLossCalculator {

  void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
                         const Matrix& log_alpha, const Matrix& log_beta,
                         float log_p_z_x, Matrix* dy) const;
                         T log_p_z_x, Matrix* dy) const;

  void GetLPrimeIndices(const std::vector<int>& l,
                        std::vector<int>* l_prime) const;
@ -103,9 +108,10 @@ class CTCLossCalculator {
  const int output_delay_;
};

template <class T>
template <typename VectorIn, typename VectorOut, typename MatrixIn,
          typename MatrixOut>
Status CTCLossCalculator::CalculateLoss(
Status CTCLossCalculator<T>::CalculateLoss(
    const VectorIn& seq_len, const LabelSequences& labels,
    const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
    bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
@ -205,11 +211,11 @@ Status CTCLossCalculator::CalculateLoss(
      // Convert label from DistBelief
      // y, prob are in num_classes x seq_len(b)
      // Output activations.
      Eigen::ArrayXf y_b_col;
      Array y_b_col;
      for (int t = 0; t < seq_len(b); t++) {
        // Calculate the softmax of y_b. Use double precision
        // Calculate the softmax of y_b. Use original precision
        // arithmetic for the sum.
        float max_coeff = inputs[t].row(b).maxCoeff();
        T max_coeff = inputs[t].row(b).maxCoeff();
        y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
        y_b.col(t) = y_b_col / y_b_col.sum();
      }
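The max subtraction in that loop is the standard overflow guard for softmax; per column it computes

y_k \;=\; \frac{\exp(x_k - m)}{\sum_{k'} \exp(x_{k'} - m)}, \qquad m = \max_{k'} x_{k'},

identical in exact arithmetic to the plain softmax, but with every exponent at most zero so exp() cannot overflow in T.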
@ -222,7 +228,7 @@ Status CTCLossCalculator::CalculateLoss(

      // The loss is computed as the log(p(z|x)) between the target and
      // prediction. Do lazy evaluation of log_prob here.
      float log_p_z_x = kLogZero;
      T log_p_z_x = kLogZero<T>::val;
      for (int u = 0; u < l_prime.size(); ++u) {
        // (GravesTh) Eq 7.26, sum over all paths for t = 0.
        log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
@ -253,19 +259,19 @@ Status CTCLossCalculator::CalculateLoss(
    // fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) +
    // grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add)).
    const int64 cost_exp = Eigen::internal::functor_traits<
        Eigen::internal::scalar_exp_op<float>>::Cost;
        Eigen::internal::scalar_exp_op<T>>::Cost;
    const int64 cost_log = Eigen::internal::functor_traits<
        Eigen::internal::scalar_log_op<float>>::Cost;
        Eigen::internal::scalar_log_op<T>>::Cost;
    const int64 cost_log_sum_exp =
        Eigen::TensorOpCost::AddCost<float>() + cost_exp + cost_log;
        Eigen::TensorOpCost::AddCost<T>() + cost_exp + cost_log;
    const int64 cost =
        max_seq_len * num_classes *
            (cost_exp + Eigen::TensorOpCost::DivCost<float>()) +
            (cost_exp + Eigen::TensorOpCost::DivCost<T>()) +
        max_seq_len * 2 * (2 * num_classes + 1) *
            (cost_log_sum_exp + cost_log) +
        max_seq_len *
            ((2 * num_classes + 1) * cost_log_sum_exp +
             num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<float>()));
             num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<T>()));
    Shard(workers->num_threads, workers->workers, batch_size, cost,
          ComputeLossAndGradients);
  } else {
@ -274,8 +280,9 @@ Status CTCLossCalculator::CalculateLoss(
  return Status::OK();
}

template <class T>
template <typename Vector>
Status CTCLossCalculator::PopulateLPrimes(
Status CTCLossCalculator<T>::PopulateLPrimes(
    bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
    int batch_size, int num_classes, const Vector& seq_len,
    const LabelSequences& labels, size_t* max_u_prime,
@ -357,6 +364,173 @@ Status CTCLossCalculator::PopulateLPrimes(
  return Status::OK();
}
// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
template <typename TT>
void CTCLossCalculator<TT>::CalculateForwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_alpha) const {
  // Number of cols is the number of time steps = number of cols in target
  // after the output delay.
  log_alpha->setConstant(kLogZero<TT>::val);

  int U = l_prime.size();
  int T = log_alpha->cols();

  CHECK_EQ(U, log_alpha->rows());

  // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
  log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
  // Below, l_prime[1] == labels[0]
  auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
  log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));

  for (int t = 1; t < T; ++t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_alpha(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.9
      // Add in the u, t - 1 term.
      auto sum_log_alpha = kLogZero<TT>::val;
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        sum_log_alpha = log_alpha->coeff(u, t - 1);
      }

      // Add in the u - 1, t - 1 term.
      if (u > 0) {
        sum_log_alpha =
            LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
      }

      // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
      if (u > 1) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha =
              LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
        }
      }
      // Multiply the summed alphas with the activation log probability.
      log_alpha->coeffRef(u, t) =
          log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
    }  // End (GravesTh) Eq 7.9.
  }
}

// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
template <class TT>
void CTCLossCalculator<TT>::CalculateBackwardVariables(
    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
    Matrix* log_beta) const {
  // Number of cols is the number of time steps = number of cols in target.
  // Matrix log_beta =
  //     Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
  //     kLogZero);
  log_beta->setConstant(kLogZero<TT>::val);
  int T = log_beta->cols();
  int U = l_prime.size();
  CHECK_EQ(U, log_beta->rows());

  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;

  for (int t = T - 1 - 1; t >= 0; --t) {
    // If there is not enough time to output the remaining labels or
    // some labels have been skipped, then let log_beta(u, t) continue to
    // be kLogZero.
    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
         ++u) {
      // Begin (GravesTh) Eq 7.15
      // Add in the u, t + 1 term.
      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u, t + 1) +
                          log(y(l_prime[u], output_delay_ + t + 1)));
      }

      // Add in the u + 1, t + 1 term.
      if (u + 1 < U) {
        log_beta->coeffRef(u, t) =
            LogSumExp(log_beta->coeff(u, t),
                      log_beta->coeff(u + 1, t + 1) +
                          log(y(l_prime[u + 1], output_delay_ + t + 1)));
      }

      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
      if (u + 2 < U) {
        const bool matching_labels_merge =
            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
          // Add in u + 2 term.
          log_beta->coeffRef(u, t) =
              LogSumExp(log_beta->coeff(u, t),
                        log_beta->coeff(u + 2, t + 1) +
                            log(y(l_prime[u + 2], output_delay_ + t + 1)));
        }
      }  // End (GravesTh) Eq. 7.15
    }
  }
}
// Using (GravesTh) Eq 7.26 & 7.34.
template <typename TT>
void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
                                              const Matrix& y,
                                              const Matrix& log_alpha,
                                              const Matrix& log_beta,
                                              TT log_p_z_x, Matrix* dy) const {
  // Only working with the leftmost part of dy for this batch element.
  auto dy_b = dy->leftCols(y.cols());

  // It is possible that no valid path is found if the activations for the
  // targets are zero.
  if (log_p_z_x == kLogZero<TT>::val) {
    LOG(WARNING) << "No valid path found.";
    dy_b = y;
    return;
  }

  int L = y.rows();
  int T = y.cols();
  int U = l_prime.size();

  for (int t = 0; t < T - output_delay_; ++t) {
    Array prob_sum(L);
    prob_sum.setConstant(kLogZero<TT>::val);

    for (int u = 0; u < U; ++u) {
      int l = l_prime[u];
      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
    }

    for (int l = 0; l < L; ++l) {
      // Negative term in (GravesTh) Eq 7.28.
      // Use std::exp so TT = double keeps full precision (expf computes
      // in float only).
      auto negative_term = std::exp(prob_sum[l] - log_p_z_x);

      dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
    }
  }
}
template <class TT>
void CTCLossCalculator<TT>::GetLPrimeIndices(const std::vector<int>& l,
                                             std::vector<int>* l_prime) const {
  // Assumption is that l_prime is empty.
  l_prime->reserve(2 * l.size() + 1);

  for (auto label : l) {
    l_prime->push_back(blank_index_);
    l_prime->push_back(label);
  }
  // Add final blank to l'.
  l_prime->push_back(blank_index_);
}

}  // namespace ctc
}  // namespace tensorflow
@ -23,18 +23,23 @@ limitations under the License.
namespace tensorflow {
namespace ctc {

const float kLogZero = -std::numeric_limits<float>::infinity();
template <class T>
struct kLogZero {
  static constexpr T val = -std::numeric_limits<T>::infinity();  // NOLINT
};
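The struct wrapper exists so the constant can be templated without C++14 features; a sketch of the equivalent C++14 variable template (an alternative spelling, not what this change uses):

#include <limits>

template <class T>
constexpr T kLogZeroValue = -std::numeric_limits<T>::infinity();
// Usage: kLogZeroValue<double> in place of kLogZero<double>::val.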
// Add logarithmic probabilities using:
// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
// The two inputs are assumed to be log probabilities.
// (GravesTh) Eq. 7.18
inline float LogSumExp(float log_prob_1, float log_prob_2) {
template <typename T>
inline T LogSumExp(T log_prob_1, T log_prob_2) {
  // const T kLogZero = -std::numeric_limits<T>::infinity();
  // Always have 'b' be the smaller number to avoid the exponential from
  // blowing up.
  if (log_prob_1 == kLogZero<T>::val) {
    return log_prob_2;
  } else if (log_prob_2 == kLogZero<T>::val) {
    return log_prob_1;
  } else {
    return (log_prob_1 > log_prob_2)
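The page truncates the final branch above; a self-contained sketch of the stable combination the comments describe (a reconstruction under that assumption, with a hypothetical name, not necessarily the file's exact code):

#include <cmath>
#include <limits>

template <typename T>
T LogSumExpSketch(T a, T b) {
  const T kZero = -std::numeric_limits<T>::infinity();
  if (a == kZero) return b;
  if (b == kZero) return a;
  // Keep the larger value outside exp() so its argument is <= 0.
  return (a > b) ? a + std::log1p(std::exp(b - a))
                 : b + std::log1p(std::exp(a - b));
}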