Add float64 support for CTC

Andrii Prymostka 2018-07-02 16:02:22 +03:00
parent bd33b05708
commit 2aabaa4b27
13 changed files with 751 additions and 346 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -33,11 +34,11 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
int* c) {
template<typename T>
inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
                int* c) {
*c = 0;
CHECK_LT(0, m.dimension(1));
float p = m(r, 0);
auto p = m(r, 0);
for (int i = 1; i < m.dimension(1); ++i) {
if (m(r, i) > p) {
p = m(r, i);
@ -170,6 +171,7 @@ class CTCDecodeHelper {
TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
template<typename T>
class CTCGreedyDecoderOp : public OpKernel {
public:
explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -189,7 +191,7 @@ class CTCGreedyDecoderOp : public OpKernel {
const TensorShape& inputs_shape = inputs->shape();
std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
const int64 max_time = inputs_shape.dim_size(0);
const int64 batch_size = inputs_shape.dim_size(1);
const int64 num_classes_raw = inputs_shape.dim_size(2);
@ -198,14 +200,14 @@ class CTCGreedyDecoderOp : public OpKernel {
errors::InvalidArgument("num_classes cannot exceed max int"));
const int num_classes = static_cast<const int>(num_classes_raw);
auto inputs_t = inputs->tensor<float, 3>();
auto inputs_t = inputs->tensor<T, 3>();
for (std::size_t t = 0; t < max_time; ++t) {
input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
batch_size, num_classes);
}
auto seq_len_t = seq_len->vec<int32>();
auto log_prob_t = log_prob->matrix<float>();
auto log_prob_t = log_prob->matrix<T>();
log_prob_t.setZero();
@ -221,7 +223,7 @@ class CTCGreedyDecoderOp : public OpKernel {
int prev_indices = -1;
for (int t = 0; t < seq_len_t(b); ++t) {
int max_class_indices;
log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
log_prob_t(b, 0) += -RowMax<T>(input_list_t[t], b, &max_class_indices);
if (max_class_indices != blank_index &&
!(merge_repeated_ && max_class_indices == prev_indices)) {
sequence.push_back(max_class_indices);
@ -250,10 +252,18 @@ class CTCGreedyDecoderOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};
REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
CTCGreedyDecoderOp);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
CTCGreedyDecoderOp<T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
#undef REGISTER_CPU
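// For reference, each REGISTER_CPU(T) above expands to an ordinary kernel
// registration with a dtype constraint, e.g. REGISTER_CPU(float) becomes:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       CTCGreedyDecoderOp<float>);
//
// so nodes whose attr T = DT_FLOAT run CTCGreedyDecoderOp<float>, and
// T = DT_DOUBLE runs CTCGreedyDecoderOp<double>.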
// CTC beam search
template<typename T>
class CTCBeamSearchDecoderOp : public OpKernel {
public:
explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -275,9 +285,9 @@ class CTCBeamSearchDecoderOp : public OpKernel {
ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
&decoded_values, &decoded_shape));
auto inputs_t = inputs->tensor<float, 3>();
auto inputs_t = inputs->tensor<T, 3>();
auto seq_len_t = seq_len->vec<int32>();
auto log_prob_t = log_prob->matrix<float>();
auto log_prob_t = log_prob->matrix<T>();
const TensorShape& inputs_shape = inputs->shape();
@ -291,21 +301,21 @@ class CTCBeamSearchDecoderOp : public OpKernel {
log_prob_t.setZero();
std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
for (std::size_t t = 0; t < max_time; ++t) {
input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
batch_size, num_classes);
}
ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
&beam_scorer_, 1 /* batch_size */,
merge_repeated_);
Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
auto input_chip_t = input_chip.flat<float>();
ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
&beam_scorer_, 1 /* batch_size */,
merge_repeated_);
Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
auto input_chip_t = input_chip.flat<T>();
std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
std::vector<float> log_probs;
std::vector<T> log_probs;
// Assumption: the blank index is num_classes - 1
for (int b = 0; b < batch_size; ++b) {
@ -313,8 +323,8 @@ class CTCBeamSearchDecoderOp : public OpKernel {
best_paths_b.resize(decode_helper_.GetTopPaths());
for (int t = 0; t < seq_len_t(b); ++t) {
input_chip_t = input_list_t[t].chip(b, 0);
auto input_bi =
Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
    input_chip_t.data(), num_classes);
beam_search.Step(input_bi);
}
OP_REQUIRES_OK(
@ -335,13 +345,20 @@ class CTCBeamSearchDecoderOp : public OpKernel {
private:
CTCDecodeHelper decode_helper_;
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
bool merge_repeated_;
int beam_width_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
};
REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
CTCBeamSearchDecoderOp);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
CTCBeamSearchDecoderOp<T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
#undef REGISTER_CPU
} // end namespace tensorflow
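// A note on the Eigen types used above: the fixed ArrayXf maps became
// generic in the scalar type. A self-contained sketch of the pattern (the
// helper name is illustrative, not part of the commit):
//
//   #include <Eigen/Core>
//
//   template <typename T>
//   Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> MapChip(
//       const T* data, int n) {
//     return Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(data, n);
//   }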

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
@ -26,14 +27,15 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template<typename T>
class CTCLossOp : public OpKernel {
typedef Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor> >
InputMap;
typedef Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
OutputMap;
public:
@ -110,7 +112,7 @@ class CTCLossOp : public OpKernel {
errors::InvalidArgument("label SparseTensor is not valid: ",
labels_sp_valid.error_message()));
ctc::CTCLossCalculator::LabelSequences labels_t(batch_size);
typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size);
for (const auto& g : labels_sp.group({0})) { // iterate by batch
const int64 batch_indices = g.group()[0];
OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
@ -137,13 +139,13 @@ class CTCLossOp : public OpKernel {
Tensor* loss = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
auto loss_t = loss->vec<float>();
auto loss_t = loss->vec<T>();
Tensor* gradient;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("gradient", inputs_shape, &gradient));
auto gradient_t = gradient->tensor<float, 3>();
auto inputs_t = inputs->tensor<float, 3>();
auto gradient_t = gradient->tensor<T, 3>();
auto inputs_t = inputs->tensor<T, 3>();
std::vector<OutputMap> gradient_list_t;
std::vector<InputMap> input_list_t;
@ -158,7 +160,7 @@ class CTCLossOp : public OpKernel {
gradient_t.setZero();
// Assumption: the blank index is num_classes - 1
ctc::CTCLossCalculator ctc_loss_calculator(num_classes - 1, 0);
ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0);
DeviceBase::CpuWorkerThreads workers =
*ctx->device()->tensorflow_cpu_worker_threads();
OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
@ -173,9 +175,17 @@ class CTCLossOp : public OpKernel {
bool ctc_merge_repeated_;
bool ignore_longer_outputs_than_inputs_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>);
};
REGISTER_KERNEL_BUILDER(Name("CTCLoss").Device(DEVICE_CPU), CTCLossOp);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
CTCLossOp<T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
#undef REGISTER_CPU
} // end namespace tensorflow

View File

@ -25,15 +25,16 @@ using shape_inference::ShapeHandle;
// CTC is Connectionist Temporal Classification. See util/ctc/ for details.
REGISTER_OP("CTCLoss")
.Input("inputs: float")
.Input("inputs: T")
.Input("labels_indices: int64")
.Input("labels_values: int32")
.Input("sequence_length: int32")
.Attr("preprocess_collapse_repeated: bool = false")
.Attr("ctc_merge_repeated: bool = true")
.Attr("ignore_longer_outputs_than_inputs: bool = false")
.Output("loss: float")
.Output("gradient: float")
.Output("loss: T")
.Output("gradient: T")
.Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle labels_indices;
@ -62,13 +63,14 @@ REGISTER_OP("CTCLoss")
});
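// For reference (unchanged by the new attr): inputs is
// [max_time, batch_size, num_classes], loss comes out as [batch_size], and
// gradient has the same shape as inputs. T is inferred from the inputs edge
// at graph construction time, so callers never set it explicitly.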
REGISTER_OP("CTCGreedyDecoder")
.Input("inputs: float")
.Input("inputs: T")
.Input("sequence_length: int32")
.Attr("merge_repeated: bool = false")
.Output("decoded_indices: int64")
.Output("decoded_values: int64")
.Output("decoded_shape: int64")
.Output("log_probability: float")
.Output("log_probability: T")
.Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle sequence_length;
@ -90,7 +92,7 @@ REGISTER_OP("CTCGreedyDecoder")
});
REGISTER_OP("CTCBeamSearchDecoder")
.Input("inputs: float")
.Input("inputs: T")
.Input("sequence_length: int32")
.Attr("beam_width: int >= 1")
.Attr("top_paths: int >= 1")
@ -98,7 +100,8 @@ REGISTER_OP("CTCBeamSearchDecoder")
.Output("decoded_indices: top_paths * int64")
.Output("decoded_values: top_paths * int64")
.Output("decoded_shape: top_paths * int64")
.Output("log_probability: float")
.Output("log_probability: T")
.Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle sequence_length;

View File

@ -59,7 +59,8 @@ tf_cc_tests(
name = "ctc_beam_search_test",
size = "small",
srcs = [
"ctc_beam_search_test.cc",
"ctc_beam_search_test_float.cc",
"ctc_beam_search_test_double.cc",
],
deps = [
":ctc_beam_search_lib",

View File

@ -41,30 +41,33 @@ namespace ctc_beam_search {
struct EmptyBeamState {};
template <typename T>
struct BeamProbability {
BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
BeamProbability() : total(kLogZero<T>::val),
blank(kLogZero<T>::val),
label(kLogZero<T>::val) {}
void Reset() {
total = kLogZero;
blank = kLogZero;
label = kLogZero;
total = kLogZero<T>::val;
blank = kLogZero<T>::val;
label = kLogZero<T>::val;
}
float total;
float blank;
float label;
T total;
T blank;
T label;
};
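// Not shown in this diff: kLogZero must have changed from a plain float
// constant to a dtype-parameterized one. The kLogZero<T>::val usage implies
// a definition along these lines (presumably in ctc_loss_util.h; needs
// <limits>):
//
//   template <typename T>
//   struct kLogZero {
//     static constexpr T val = -std::numeric_limits<T>::infinity();  // log(0)
//   };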
template <class CTCBeamState>
template <class T, class CTCBeamState>
class BeamRoot;
template <class CTCBeamState = EmptyBeamState>
template <class T, class CTCBeamState = EmptyBeamState>
struct BeamEntry {
// BeamRoot<T, CTCBeamState>::AddEntry() serves as the factory method.
friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
BeamEntry<CTCBeamState>* p, int l);
inline bool Active() const { return newp.total != kLogZero; }
friend BeamEntry<T, CTCBeamState>*
BeamRoot<T, CTCBeamState>::AddEntry(BeamEntry<T, CTCBeamState>* p, int l);
inline bool Active() const { return newp.total != kLogZero<T>::val; }
// Return the child at the given index, or construct a new one in-place if
// none was found.
BeamEntry& GetChild(int ind) {
BeamEntry<T, CTCBeamState>& GetChild(int ind) {
auto entry = children.emplace(ind, nullptr);
auto& child_entry = entry.first->second;
// If this is a new child, populate the BeamEntry<T, CTCBeamState>*.
@ -76,7 +79,7 @@ struct BeamEntry {
std::vector<int> LabelSeq(bool merge_repeated) const {
std::vector<int> labels;
int prev_label = -1;
const BeamEntry* c = this;
const BeamEntry<T, CTCBeamState>* c = this;
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
if (!merge_repeated || c->label != prev_label) {
labels.push_back(c->label);
@ -88,12 +91,12 @@ struct BeamEntry {
return labels;
}
BeamEntry<CTCBeamState>* parent;
BeamEntry<T, CTCBeamState>* parent;
int label;
// All instances of child BeamEntry are owned by *beam_root.
gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
BeamProbability oldp;
BeamProbability newp;
gtl::FlatMap<int, BeamEntry<T, CTCBeamState>*> children;
BeamProbability<T> oldp;
BeamProbability<T> newp;
CTCBeamState state;
private:
@ -102,40 +105,40 @@ struct BeamEntry {
// otherwise parent will become invalid.
// This private constructor is only called through the factory method
// BeamRoot<T, CTCBeamState>::AddEntry().
BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
BeamEntry(BeamEntry* p, int l, BeamRoot<T, CTCBeamState>* beam_root)
: parent(p), label(l), beam_root(beam_root) {}
BeamRoot<CTCBeamState>* beam_root;
BeamRoot<T, CTCBeamState>* beam_root;
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
};
// This class owns all instances of BeamEntry. This is used to avoid recursive
// destructor call during destruction.
template <class CTCBeamState = EmptyBeamState>
template <class T, class CTCBeamState = EmptyBeamState>
class BeamRoot {
public:
BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
BeamRoot(BeamEntry<T, CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
BeamRoot(const BeamRoot&) = delete;
BeamRoot& operator=(const BeamRoot&) = delete;
BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
BeamEntry<T, CTCBeamState>* AddEntry(BeamEntry<T, CTCBeamState>* p, int l) {
auto* new_entry = new BeamEntry<T, CTCBeamState>(p, l, this);
beam_entries_.emplace_back(new_entry);
return new_entry;
}
BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
BeamEntry<T, CTCBeamState>* RootEntry() const { return root_entry_; }
private:
BeamEntry<CTCBeamState>* root_entry_ = nullptr;
std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
BeamEntry<T, CTCBeamState>* root_entry_ = nullptr;
std::vector<std::unique_ptr<BeamEntry<T, CTCBeamState>>> beam_entries_;
};
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
template <class CTCBeamState = EmptyBeamState>
template <class T, class CTCBeamState = EmptyBeamState>
class BeamComparer {
public:
virtual ~BeamComparer() {}
virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
const BeamEntry<CTCBeamState>* b) const {
virtual bool inline operator()(const BeamEntry<T, CTCBeamState>* a,
const BeamEntry<T, CTCBeamState>* b) const {
return a->newp.total > b->newp.total;
}
};

View File

@ -34,7 +34,7 @@ namespace ctc {
// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex
// scoring is required. Its main purpose is to provide a thin layer for
// integrating language model scoring easily.
template <typename CTCBeamState>
template <typename T, typename CTCBeamState>
class BaseBeamScorer {
public:
virtual ~BaseBeamScorer() {}
@ -56,8 +56,8 @@ class BaseBeamScorer {
//
// The score returned should be a log-probability. In the simplest case, as
// there's no state expansion logic, the expansion score is zero.
virtual float GetStateExpansionScore(const CTCBeamState& state,
float previous_score) const {
virtual T GetStateExpansionScore(const CTCBeamState& state,
T previous_score) const {
return previous_score;
}
// GetStateEndExpansionScore should be an inexpensive method to retrieve the
@ -65,8 +65,8 @@ class BaseBeamScorer {
// multiplied (log-addition) with the final probability of the beam.
//
// The score returned should be a log-probability.
virtual float GetStateEndExpansionScore(const CTCBeamState& state) const {
return 0;
virtual T GetStateEndExpansionScore(const CTCBeamState& state) const {
return T(0);
}
};
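// Because every member has a pass-through or zero default, the base class
// itself acts as a usable scorer (see DefaultBeamScorer in
// ctc_beam_search.h). A hypothetical custom scorer for T = double (sketch,
// not part of the commit; assumes EmptyBeamState from ctc_beam_search.h and
// <cmath>):
//
//   class FlatPenaltyScorer
//       : public BaseBeamScorer<double, ctc_beam_search::EmptyBeamState> {
//    public:
//     double GetStateEndExpansionScore(
//         const ctc_beam_search::EmptyBeamState& state) const override {
//       return std::log(0.5);  // flat log-space penalty at beam end
//     }
//   };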

View File

@ -38,10 +38,11 @@ limitations under the License.
namespace tensorflow {
namespace ctc {
template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
template <typename T,
typename CTCBeamState = ctc_beam_search::EmptyBeamState,
typename CTCBeamComparer =
ctc_beam_search::BeamComparer<CTCBeamState>>
class CTCBeamSearchDecoder : public CTCDecoder {
ctc_beam_search::BeamComparer<T, CTCBeamState>>
class CTCBeamSearchDecoder : public CTCDecoder<T> {
// Beam Search
//
// Example (GravesTh Fig. 7.5):
@ -73,12 +74,12 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// starts at 0). This special case can be calculated as:
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
// but we calculate it recursively for speed purposes.
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability BeamProbability;
typedef ctc_beam_search::BeamEntry<T, CTCBeamState> BeamEntry;
typedef ctc_beam_search::BeamRoot<T, CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability<T> BeamProbability;
public:
typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
typedef BaseBeamScorer<T, CTCBeamState> DefaultBeamScorer;
// The beam search decoder is constructed specifying the beam_width (number of
// candidates to keep at each decoding timestep) and a beam scorer (used for
@ -87,9 +88,10 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
// standard beam search.
CTCBeamSearchDecoder(int num_classes, int beam_width,
BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
BaseBeamScorer<T, CTCBeamState>* scorer,
int batch_size = 1,
bool merge_repeated = false)
: CTCDecoder(num_classes, batch_size, merge_repeated),
: CTCDecoder<T>(num_classes, batch_size, merge_repeated),
beam_width_(beam_width),
leaves_(beam_width),
beam_scorer_(CHECK_NOTNULL(scorer)) {
@ -99,27 +101,27 @@ class CTCBeamSearchDecoder : public CTCDecoder {
~CTCBeamSearchDecoder() override {}
// Run the hibernating beam search algorithm on the given input.
Status Decode(const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores) override;
Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
const std::vector<typename CTCDecoder<T>::Input>& input,
std::vector<typename CTCDecoder<T>::Output>* output,
typename CTCDecoder<T>::ScoreOutput* scores) override;
// Calculate the next step of the beam search and update the internal state.
template <typename Vector>
void Step(const Vector& log_input_t);
template <typename Vector>
float GetTopK(const int K, const Vector& input,
std::vector<float>* top_k_logits,
std::vector<int>* top_k_indices);
T GetTopK(const int K, const Vector& input,
std::vector<T>* top_k_logits,
std::vector<int>* top_k_indices);
// Retrieve the beam scorer instance used during decoding.
BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
BaseBeamScorer<T, CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
// Set label selection parameters for faster decoding.
// See comments for label_selection_size_ and label_selection_margin_.
void SetLabelSelectionParameters(int label_selection_size,
float label_selection_margin) {
T label_selection_margin) {
label_selection_size_ = label_selection_size;
label_selection_margin_ = label_selection_margin;
}
@ -129,7 +131,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// Extract the top n paths at current time step
Status TopPaths(int n, std::vector<std::vector<int>>* paths,
std::vector<float>* log_probs, bool merge_repeated) const;
std::vector<T>* log_probs, bool merge_repeated) const;
private:
int beam_width_;
@ -145,37 +147,38 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// Default is to do no label selection.
// For more detail: https://research.google.com/pubs/pub44823.html
int label_selection_size_ = 0; // zero means unlimited
float label_selection_margin_ = -1; // -1 means unlimited.
T label_selection_margin_ = -1; // -1 means unlimited.
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
std::unique_ptr<BeamRoot> beam_root_;
BaseBeamScorer<CTCBeamState>* beam_scorer_;
BaseBeamScorer<T, CTCBeamState>* beam_scorer_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
};
template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Decode(
const typename CTCDecoder<T>::SequenceLength& seq_len,
const std::vector<typename CTCDecoder<T>::Input>& input,
std::vector<typename CTCDecoder<T>::Output>* output,
typename CTCDecoder<T>::ScoreOutput* scores) {
// Storage for top paths.
std::vector<std::vector<int>> beams;
std::vector<float> beam_log_probabilities;
std::vector<T> beam_log_probabilities;
int top_n = output->size();
if (std::any_of(output->begin(), output->end(),
[this](const CTCDecoder::Output& output) -> bool {
[this](const typename CTCDecoder<T>::Output& output) -> bool {
return output.size() < this->batch_size_;
})) {
return errors::InvalidArgument(
"output needs to be of size at least (top_n, batch_size).");
}
if (scores->rows() < batch_size_ || scores->cols() < top_n) {
if (scores->rows() < this->batch_size_ || scores->cols() < top_n) {
return errors::InvalidArgument(
"scores needs to be of size at least (batch_size, top_n).");
}
for (int b = 0; b < batch_size_; ++b) {
for (int b = 0; b < this->batch_size_; ++b) {
int seq_len_b = seq_len[b];
Reset();
@ -196,7 +199,7 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
}
Status status =
TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_);
if (!status.ok()) {
return status;
}
@ -213,20 +216,20 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
return Status::OK();
}
template <typename CTCBeamState, typename CTCBeamComparer>
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
const int K, const Vector& input, std::vector<float>* top_k_logits,
T CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::GetTopK(
const int K, const Vector& input, std::vector<T>* top_k_logits,
std::vector<int>* top_k_indices) {
// Find Top K choices, complexity nk in worst case. The array input is read
// just once.
CHECK_EQ(num_classes_, input.size());
CHECK_EQ(this->num_classes_, input.size());
top_k_logits->clear();
top_k_indices->clear();
top_k_logits->resize(K, -INFINITY);
top_k_indices->resize(K, -1);
for (int j = 0; j < num_classes_ - 1; ++j) {
const float logit = input(j);
for (int j = 0; j < this->num_classes_ - 1; ++j) {
const T logit = input(j);
if (logit > (*top_k_logits)[K - 1]) {
int k = K - 1;
while (k > 0 && logit > (*top_k_logits)[k - 1]) {
@ -239,43 +242,43 @@ float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
}
}
// Return max value which is in 0th index or blank character logit
return std::max((*top_k_logits)[0], input(num_classes_ - 1));
return std::max((*top_k_logits)[0], input(this->num_classes_ - 1));
}
template <typename CTCBeamState, typename CTCBeamComparer>
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
const Vector& raw_input) {
std::vector<float> top_k_logits;
std::vector<T> top_k_logits;
std::vector<int> top_k_indices;
const bool top_k =
(label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
// Number of character classes to consider in each step.
const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
const int max_classes = top_k ? label_selection_size_ : (this->num_classes_ - 1);
// Get max coefficient and remove it from raw_input later.
float max_coeff;
T max_coeff;
if (top_k) {
max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
&top_k_indices);
} else {
max_coeff = raw_input.maxCoeff();
}
// Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
float logsumexp = 0.0;
T logsumexp = T(0.0);
for (int j = 0; j < raw_input.size(); ++j) {
logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
}
logsumexp = Eigen::numext::log(logsumexp);
// Final normalization offset to get correct log probabilities.
float norm_offset = max_coeff + logsumexp;
T norm_offset = max_coeff + logsumexp;
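// Equivalently: log p_j = logit_j - norm_offset, where
//   norm_offset = max_k logit_k + log(sum_k exp(logit_k - max_k logit_k)).
// The max shift keeps the exp() calls in range for float and double alike.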
const float label_selection_input_min =
const T label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
: -std::numeric_limits<T>::infinity();
// Extract the beams sorted in decreasing new probability
CHECK_EQ(num_classes_, raw_input.size());
CHECK_EQ(this->num_classes_, raw_input.size());
std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
leaves_.Reset();
@ -294,8 +297,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// else:
// Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
// + P(l=ab @ t=5))
float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
: b->parent->oldp.total;
T previous = (b->label == b->parent->label) ? b->parent->oldp.blank
: b->parent->oldp.total;
b->newp.label =
LogSumExp(b->newp.label,
beam_scorer_->GetStateExpansionScore(b->state, previous));
@ -304,7 +307,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
b->newp.blank = b->oldp.total + raw_input(this->blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@ -325,7 +328,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// isn't full, or the lowest probability entry in the beam has a
// lower probability than the leaf.
auto is_candidate = [this](const BeamProbability& prob) {
return (prob.total > kLogZero &&
return (prob.total > kLogZero<T>::val &&
(leaves_.size() < beam_width_ ||
prob.total > leaves_.peek_bottom()->newp.total));
};
@ -336,7 +339,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
for (int ind = 0; ind < max_classes; ind++) {
const int label = top_k ? top_k_indices[ind] : ind;
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
const T logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
// We may compare logits instead of log probabilities,
@ -347,13 +350,13 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
BeamEntry& c = b->GetChild(label);
if (!c.Active()) {
// Pblank(l=abcd @ t=6) = 0
c.newp.blank = kLogZero;
c.newp.blank = kLogZero<T>::val;
// If new child label is identical to beam label:
// Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
// Otherwise:
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
T previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
@ -379,15 +382,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} // for (BeamEntry* b...
}
template <typename CTCBeamState, typename CTCBeamComparer>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Reset() {
leaves_.Reset();
// This beam root, and all of its children, will be in memory until
// the next reset.
beam_root_.reset(new BeamRoot(nullptr, -1));
beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
beam_root_->RootEntry()->newp.total = T(0.0); // ln(1)
beam_root_->RootEntry()->newp.blank = T(0.0); // ln(1)
// Add the root as the initial leaf.
leaves_.push(beam_root_->RootEntry());
@ -396,9 +399,9 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
}
template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::TopPaths(
int n, std::vector<std::vector<int>>* paths, std::vector<T>* log_probs,
bool merge_repeated) const {
CHECK_NOTNULL(paths)->clear();
CHECK_NOTNULL(log_probs)->clear();
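// Putting the new scalar parameter to use, mirroring
// CTCBeamSearchDecoderOp<T>::Compute above (num_classes and beam_width
// assumed in scope; a sketch, not part of the commit):
//
//   CTCBeamSearchDecoder<double>::DefaultBeamScorer scorer;
//   CTCBeamSearchDecoder<double> decoder(num_classes, beam_width, &scorer,
//                                        1 /* batch_size */,
//                                        false /* merge_repeated */);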

View File

@ -0,0 +1,346 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This test illustrates how to make use of the CTCBeamSearchDecoder using a
// custom BeamScorer and BeamState based on a dictionary with a few artificial
// words.
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include <cmath>
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace {
typedef std::vector<std::vector<std::vector<double>>> TestData;
// Alias template so the test can write CTCBeamSearchDecoder<> and
// CTCBeamSearchDecoder<SomeBeamState> with double as the scalar type.
template <class... Args>
using CTCBeamSearchDecoder =
    tensorflow::ctc::CTCBeamSearchDecoder<double, Args...>;
typedef tensorflow::ctc::CTCDecoder<double> CTCDecoder;
// The HistoryBeamState is used to keep track of the current candidate and
// caches the expansion score (needed by the scorer).
struct HistoryBeamState {
double score;
std::vector<int> labels;
};
// DictionaryBeamScorer essentially favors candidates that can still become
// dictionary words. As soon as a beam candidate is not a dictionary word or
// a prefix of a dictionary word it gets a low probability at each step.
//
// The dictionary itself is hard-coded a static const variable of the class.
class DictionaryBeamScorer
: public tensorflow::ctc::BaseBeamScorer<double, HistoryBeamState> {
public:
void InitializeState(HistoryBeamState* root) const override {
root->score = 0;
}
void ExpandState(const HistoryBeamState& from_state, int from_label,
HistoryBeamState* to_state, int to_label) const override {
// Keep track of the current complete candidate by storing the labels along
// the expansion path in the beam state.
to_state->labels.push_back(to_label);
SetStateScoreAccordingToDict(to_state);
}
void ExpandStateEnd(HistoryBeamState* state) const override {
SetStateScoreAccordingToDict(state);
}
double GetStateExpansionScore(const HistoryBeamState& state,
double previous_score) const override {
return previous_score + state.score;
}
double GetStateEndExpansionScore(
const HistoryBeamState& state) const override {
return state.score;
}
// Simple dictionary used when scoring the beams to check if they are prefixes
// of dictionary words (see SetStateScoreAccordingToDict below).
static const std::vector<std::vector<int>> dictionary_;
private:
void SetStateScoreAccordingToDict(HistoryBeamState* state) const;
};
const std::vector<std::vector<int>> DictionaryBeamScorer::dictionary_ = {
{3}, {3, 1}};
void DictionaryBeamScorer::SetStateScoreAccordingToDict(
HistoryBeamState* state) const {
// Check if the beam can still be a dictionary word (e.g. prefix of one).
const std::vector<int>& candidate = state->labels;
for (int w = 0; w < dictionary_.size(); ++w) {
const std::vector<int>& word = dictionary_[w];
// If the length of the current beam is already larger, skip.
if (candidate.size() > word.size()) {
continue;
}
if (std::equal(word.begin(), word.begin() + candidate.size(),
candidate.begin())) {
state->score = std::log(1.0);
return;
}
}
// At this point, the candidate certainly can't be in the dictionary.
state->score = std::log(0.01);
}
TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
const int batch_size = 1;
const int timesteps = 5;
const int top_paths = 3;
const int num_classes = 6;
// Plain decoder using hibernating beam search algorithm.
CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
CTCBeamSearchDecoder<> decoder(num_classes, 10 * top_paths, &default_scorer);
// Dictionary decoder, allowing only two dictionary words : {3}, {3, 1}.
DictionaryBeamScorer dictionary_scorer;
CTCBeamSearchDecoder<HistoryBeamState> dictionary_decoder(
num_classes, top_paths, &dictionary_scorer);
// Raw data containers (arrays of doubles, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
double input_data_mat[timesteps][batch_size][num_classes] = {
{{0, 0.6, 0, 0.4, 0, 0}},
{{0, 0.5, 0, 0.5, 0, 0}},
{{0, 0.4, 0, 0.6, 0, 0}},
{{0, 0.4, 0, 0.6, 0, 0}},
{{0, 0.4, 0, 0.6, 0, 0}}};
// The CTCDecoder works with log-probs.
for (int t = 0; t < timesteps; ++t) {
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < num_classes; ++c) {
input_data_mat[t][b][c] = std::log(input_data_mat[t][b][c]);
}
}
}
// Plain output, without any additional scoring.
std::vector<CTCDecoder::Output> expected_output = {
{{1, 3}, {1, 3, 1}, {3, 1, 3}},
};
// Dictionary outputs: preference for dictionary candidates. The
// second-candidate is there, despite it not being a dictionary word, due to
// stronger probability in the input to the decoder.
std::vector<CTCDecoder::Output> expected_dict_output = {
{{3}, {1, 3}, {3, 1}},
};
// Convert data containers to the format accepted by the decoder, simply
// mapping the memory from the container to an Eigen::ArrayXi / Eigen::MatrixXd,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXd>> inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
}
// Prepare containers for output and scores.
std::vector<CTCDecoder::Output> outputs(top_paths);
for (CTCDecoder::Output& output : outputs) {
output.resize(batch_size);
}
double score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXd> scores(&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_output[0][path]);
}
// Prepare dictionary outputs.
std::vector<CTCDecoder::Output> dict_outputs(top_paths);
for (CTCDecoder::Output& output : dict_outputs) {
output.resize(batch_size);
}
EXPECT_TRUE(
dictionary_decoder.Decode(seq_len, inputs, &dict_outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(dict_outputs[path][0], expected_dict_output[0][path]);
}
}
TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
const int batch_size = 1;
const int timesteps = 1;
const int top_paths = 3;
const int num_classes = 6;
// Plain decoder using hibernating beam search algorithm.
CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
CTCBeamSearchDecoder<> decoder(num_classes, top_paths, &default_scorer);
// Raw data containers (arrays of doubles, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
double input_data_mat[timesteps][batch_size][num_classes] = {
{{0.4, 0.3, 0, 0, 0, 0.5}}};
// Convert data containers to the format accepted by the decoder, simply
// mapping the memory from the container to an Eigen::ArrayXi / Eigen::MatrixXd,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXd>> inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
}
// Prepare containers for output and scores.
std::vector<CTCDecoder::Output> outputs(top_paths);
for (CTCDecoder::Output& output : outputs) {
output.resize(batch_size);
}
double score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXd> scores(&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
// Make sure all scores are finite.
for (int path = 0; path < top_paths; ++path) {
LOG(INFO) << "path " << path;
EXPECT_FALSE(std::isinf(score[0][path]));
}
}
// A beam decoder to test label selection. It simply models N labels with
// rapidly dropping off log-probability.
typedef int LabelState; // The state is simply the final label.
class RapidlyDroppingLabelScorer
: public tensorflow::ctc::BaseBeamScorer<double, LabelState> {
public:
void InitializeState(LabelState* root) const override {}
void ExpandState(const LabelState& from_state, int from_label,
LabelState* to_state, int to_label) const override {
*to_state = to_label;
}
void ExpandStateEnd(LabelState* state) const override {}
double GetStateExpansionScore(const LabelState& state,
double previous_score) const override {
// Drop off rapidly for later labels.
const double kRapidly = 100;
return previous_score - kRapidly * state;
}
double GetStateEndExpansionScore(const LabelState& state) const override {
return 0;
}
};
TEST(CtcBeamSearch, LabelSelection) {
const int batch_size = 1;
const int timesteps = 3;
const int top_paths = 5;
const int num_classes = 6;
// Decoder which drops off log-probabilities for labels 0 >> 1 >> 2 >> 3.
RapidlyDroppingLabelScorer scorer;
CTCBeamSearchDecoder<LabelState> decoder(num_classes, top_paths, &scorer);
// Raw data containers (arrays of doubles, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
// Log probabilities, slightly preferring later labels, this decision
// should be overridden by the scorer which strongly prefers earlier labels.
// The last one is empty label, and for simplicity we give it an extremely
// high cost to ignore it. We also use the first label to break up the
// repeated label sequence.
double input_data_mat[timesteps][batch_size][num_classes] = {
{{-1e6, 1, 2, 3, 4, -1e6}},
{{1e6, 0, 0, 0, 0, -1e6}}, // force label 0 to break up repeated
{{-1e6, 1.1, 2.2, 3.3, 4.4, -1e6}},
};
// Expected output without label selection
std::vector<CTCDecoder::Output> expected_default_output = {
{{1, 0, 1}, {1, 0, 2}, {2, 0, 1}, {1, 0, 3}, {2, 0, 2}},
};
// Expected output with label selection limiting to 2 items
// this is suboptimal because only labels 3 and 4 were allowed to be seen.
std::vector<CTCDecoder::Output> expected_output_size2 = {
{{3, 0, 3}, {3, 0, 4}, {4, 0, 3}, {4, 0, 4}, {3}},
};
// Expected output with label width of 2.0. This would permit three labels at
// the first timestep, but only two at the last.
std::vector<CTCDecoder::Output> expected_output_width2 = {
{{2, 0, 3}, {2, 0, 4}, {3, 0, 3}, {3, 0, 4}, {4, 0, 3}},
};
// Convert data containers to the format accepted by the decoder, simply
// mapping the memory from the container to an Eigen::ArrayXi / Eigen::MatrixXd,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXd>> inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
}
// Prepare containers for output and scores.
std::vector<CTCDecoder::Output> outputs(top_paths);
for (CTCDecoder::Output& output : outputs) {
output.resize(batch_size);
}
double score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXd> scores(&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_default_output[0][path]);
}
// Try label selection size 2
decoder.SetLabelSelectionParameters(2, -1);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
}
// Try label selection width 2.0
decoder.SetLabelSelectionParameters(0, 2.0);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_output_width2[0][path]);
}
// Try both size 2 and width 2.0: the former is more constraining, so
// it's equivalent to that.
decoder.SetLabelSelectionParameters(2, 2.0);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
}
// Size 4 and width > 3.3 are equivalent to no label selection
decoder.SetLabelSelectionParameters(4, 3.3001);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_default_output[0][path]);
}
}
} // namespace

View File

@ -25,8 +25,8 @@ limitations under the License.
namespace {
typedef std::vector<std::vector<std::vector<float>>> TestData;
using tensorflow::ctc::CTCBeamSearchDecoder;
using tensorflow::ctc::CTCDecoder;
// Alias template so the test can write CTCBeamSearchDecoder<> and
// CTCBeamSearchDecoder<SomeBeamState> with float as the scalar type.
template <class... Args>
using CTCBeamSearchDecoder =
    tensorflow::ctc::CTCBeamSearchDecoder<float, Args...>;
typedef tensorflow::ctc::CTCDecoder<float> CTCDecoder;
// The HistoryBeamState is used to keep track of the current candidate and
// caches the expansion score (needed by the scorer).
@ -115,7 +115,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
CTCBeamSearchDecoder<HistoryBeamState> dictionary_decoder(
num_classes, top_paths, &dictionary_scorer);
// Raw data containers (arrays of floats, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
float input_data_mat[timesteps][batch_size][num_classes] = {
{{0, 0.6, 0, 0.4, 0, 0}},
@ -149,7 +149,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
// mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
@ -161,7 +161,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
output.resize(batch_size);
}
float score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
@ -190,7 +190,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
CTCBeamSearchDecoder<> decoder(num_classes, top_paths, &default_scorer);
// Raw data containers (arrays of floats, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
float input_data_mat[timesteps][batch_size][num_classes] = {
{{0.4, 0.3, 0, 0, 0, 0.5}}};
@ -199,7 +199,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
// mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
@ -211,7 +211,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
output.resize(batch_size);
}
float score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
// Make sure all scores are finite.
@ -260,7 +260,7 @@ TEST(CtcBeamSearch, LabelSelection) {
RapidlyDroppingLabelScorer scorer;
CTCBeamSearchDecoder<LabelState> decoder(num_classes, top_paths, &scorer);
// Raw data containers (arrays of floats, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
// Log probabilities, slightly preferring later labels, this decision
// should be overridden by the scorer which strongly prefers earlier labels.
@ -294,7 +294,7 @@ TEST(CtcBeamSearch, LabelSelection) {
// mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
@ -306,7 +306,7 @@ TEST(CtcBeamSearch, LabelSelection) {
output.resize(batch_size);
}
float score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {

View File

@ -33,12 +33,13 @@ namespace ctc {
// The two types of decoding available are:
// - greedy path, through the CTCGreedyDecoder
// - beam search, through the CTCBeamSearchDecoder
template<class T>
class CTCDecoder {
public:
typedef Eigen::Map<const Eigen::ArrayXi> SequenceLength;
typedef Eigen::Map<const Eigen::MatrixXf> Input;
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> Input;
typedef std::vector<std::vector<int>> Output;
typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> ScoreOutput;
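// For T = float these aliases reduce exactly to the previous fixed types
// (Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> is Eigen::MatrixXf),
// so existing float callers see the same Input/ScoreOutput as before, while
// T = double swaps in MatrixXd-based maps.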
CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
: num_classes_(num_classes),
@ -69,25 +70,27 @@ class CTCDecoder {
// CTCGreedyDecoder is an implementation of the simple best path decoding
// algorithm, selecting at each timestep the most likely class at each timestep.
class CTCGreedyDecoder : public CTCDecoder {
template<class T>
class CTCGreedyDecoder : public CTCDecoder<T> {
public:
typedef CTCDecoder<T> Decoder;
CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
: CTCDecoder(num_classes, batch_size, merge_repeated) {}
: CTCDecoder<T>(num_classes, batch_size, merge_repeated) {}
Status Decode(const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores) override {
if (output->empty() || (*output)[0].size() < batch_size_) {
Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
const std::vector<typename CTCDecoder<T>::Input>& input,
std::vector<typename CTCDecoder<T>::Output>* output,
typename CTCDecoder<T>::ScoreOutput* scores) override {
if (output->empty() || (*output)[0].size() < Decoder::batch_size_) {
return errors::InvalidArgument(
"output needs to be of size at least (1, batch_size).");
}
if (scores->rows() < batch_size_ || scores->cols() == 0) {
if (scores->rows() < Decoder::batch_size_ || scores->cols() == 0) {
return errors::InvalidArgument(
"scores needs to be of size at least (batch_size, 1).");
}
// For each batch entry, identify the transitions
for (int b = 0; b < batch_size_; ++b) {
for (int b = 0; b < Decoder::batch_size_; ++b) {
int seq_len_b = seq_len[b];
// Only writing to beam 0
std::vector<int>& output_b = (*output)[0][b];
@ -98,8 +101,8 @@ class CTCGreedyDecoder : public CTCDecoder {
auto row = input[t].row(b);
int max_class_ix;
(*scores)(b, 0) += -row.maxCoeff(&max_class_ix);
if (max_class_ix != blank_index_ &&
!(merge_repeated_ && max_class_ix == prev_class_ix)) {
if (max_class_ix != Decoder::blank_index_ &&
!(Decoder::merge_repeated_ && max_class_ix == prev_class_ix)) {
output_b.push_back(max_class_ix);
}
prev_class_ix = max_class_ix;

View File

@ -19,168 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace ctc {
// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
void CTCLossCalculator::CalculateForwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_alpha) const {
// Number of cols is the number of time steps = number of cols in target
// after the output delay.
log_alpha->setConstant(kLogZero);
int U = l_prime.size();
int T = log_alpha->cols();
CHECK_EQ(U, log_alpha->rows());
// Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
log_alpha->coeffRef(0, 0) = std::log(y(blank_index_, output_delay_));
// Below, l_prime[1] == labels[0]
auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
log_alpha->coeffRef(1, 0) = std::log(y(label_0, output_delay_));
for (int t = 1; t < T; ++t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_alpha(u, t) continue to
// be kLogZero.
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.9
// Add in the u, t - 1 term.
float sum_log_alpha = kLogZero;
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
sum_log_alpha = log_alpha->coeff(u, t - 1);
}
// Add in the u - 1, t - 1 term.
if (u > 0) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
}
// Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
if (u > 1) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
}
}
// Multiply the summed alphas with the activation log probability.
log_alpha->coeffRef(u, t) =
std::log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
} // End (GravesTh) Eq 7.9.
}
}
// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
void CTCLossCalculator::CalculateBackwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_beta) const {
// Number of cols is the number of time steps = number of cols in target.
// Matrix log_beta =
// Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
// kLogZero);
log_beta->setConstant(kLogZero);
int T = log_beta->cols();
int U = l_prime.size();
CHECK_EQ(U, log_beta->rows());
// Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
for (int t = T - 1 - 1; t >= 0; --t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_beta(u, t) continue to
// be kLogZero.
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.15
// Add in the u, t + 1 term.
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u, t + 1) +
std::log(y(l_prime[u], output_delay_ + t + 1)));
}
// Add in the u + 1, t + 1 term.
if (u + 1 < U) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 1, t + 1) +
std::log(y(l_prime[u + 1], output_delay_ + t + 1)));
}
// Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
if (u + 2 < U) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
// Add in u + 2 term.
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 2, t + 1) +
std::log(y(l_prime[u + 2], output_delay_ + t + 1)));
}
} // End (GravesTh) Eq. 7.15
}
}
}
// Using (GravesTh) Eq 7.26 & 7.34.
void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
const Matrix& y,
const Matrix& log_alpha,
const Matrix& log_beta,
float log_p_z_x, Matrix* dy) const {
// Only working with the leftmost part of dy for this batch element.
auto dy_b = dy->leftCols(y.cols());
// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero) {
LOG(WARNING) << "No valid path found.";
dy_b = y;
return;
}
int L = y.rows();
int T = y.cols();
int U = l_prime.size();
for (int t = 0; t < T - output_delay_; ++t) {
Array prob_sum(L);
prob_sum.setConstant(kLogZero);
for (int u = 0; u < U; ++u) {
int l = l_prime[u];
prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
}
for (int l = 0; l < L; ++l) {
// Negative term in (GravesTh) Eq 7.28.
float negative_term = expf(prob_sum[l] - log_p_z_x);
dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
}
}
}
void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
std::vector<int>* l_prime) const {
// Assumption is that l_prime is empty.
l_prime->reserve(2 * l.size() + 1);
for (auto label : l) {
l_prime->push_back(blank_index_);
l_prime->push_back(label);
}
// Add final blank to l'.
l_prime->push_back(blank_index_);
}
} // namespace ctc
} // namespace tensorflow

View File

@ -30,6 +30,7 @@ limitations under the License.
namespace tensorflow {
namespace ctc {
template<class T>
class CTCLossCalculator {
// Connectionist Temporal Classification Loss
//
@ -50,10 +51,14 @@ class CTCLossCalculator {
// Neural Networks" (PhD Thesis), Technische Universit¨at M¨unchen.
public:
typedef std::vector<std::vector<int>> LabelSequences;
typedef Eigen::MatrixXf Matrix;
typedef Eigen::ArrayXf Array;
typedef Eigen::Map<const Eigen::MatrixXf> InputMap;
typedef Eigen::Map<Eigen::MatrixXf> OutputMap;
using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
using Array = Eigen::Array<T, Eigen::Dynamic, 1>;
using InputMap = Eigen::Map<const Matrix>;
using OutputMap = Eigen::Map<Matrix>;
CTCLossCalculator(int blank_index, int output_delay)
: blank_index_(blank_index), output_delay_(output_delay) {}
@ -79,7 +84,7 @@ class CTCLossCalculator {
void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
const Matrix& log_alpha, const Matrix& log_beta,
float log_p_z_x, Matrix* dy) const;
T log_p_z_x, Matrix* dy) const;
void GetLPrimeIndices(const std::vector<int>& l,
std::vector<int>* l_prime) const;
@ -103,9 +108,10 @@ class CTCLossCalculator {
const int output_delay_;
};
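// The CTCLossOp kernel above instantiates this per dtype, e.g. for double
// (blank index is num_classes - 1, no output delay):
//
//   ctc::CTCLossCalculator<double> calculator(num_classes - 1, 0);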
template <class T>
template <typename VectorIn, typename VectorOut, typename MatrixIn,
typename MatrixOut>
Status CTCLossCalculator::CalculateLoss(
Status CTCLossCalculator<T>::CalculateLoss(
const VectorIn& seq_len, const LabelSequences& labels,
const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
@ -205,11 +211,11 @@ Status CTCLossCalculator::CalculateLoss(
// Convert label from DistBelief
// y, prob are in num_classes x seq_len(b)
// Output activations.
Eigen::ArrayXf y_b_col;
Array y_b_col;
for (int t = 0; t < seq_len(b); t++) {
// Calculate the softmax of y_b. Use double precision
// Calculate the softmax of y_b. Use the input type's precision
// arithmetic for the sum.
float max_coeff = inputs[t].row(b).maxCoeff();
T max_coeff = inputs[t].row(b).maxCoeff();
y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
y_b.col(t) = y_b_col / y_b_col.sum();
}
@ -222,7 +228,7 @@ Status CTCLossCalculator::CalculateLoss(
// The loss is computed as the log(p(z|x)) between the target and
// prediction. Do lazy evaluation of log_prob here.
float log_p_z_x = kLogZero;
T log_p_z_x = kLogZero<T>::val;
for (int u = 0; u < l_prime.size(); ++u) {
// (GravesTh) Eq 7.26, sum over all paths for t = 0.
log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
@ -253,19 +259,19 @@ Status CTCLossCalculator::CalculateLoss(
// fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) +
// grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add)).
const int64 cost_exp = Eigen::internal::functor_traits<
Eigen::internal::scalar_exp_op<float>>::Cost;
Eigen::internal::scalar_exp_op<T>>::Cost;
const int64 cost_log = Eigen::internal::functor_traits<
Eigen::internal::scalar_log_op<float>>::Cost;
Eigen::internal::scalar_log_op<T>>::Cost;
const int64 cost_log_sum_exp =
Eigen::TensorOpCost::AddCost<float>() + cost_exp + cost_log;
Eigen::TensorOpCost::AddCost<T>() + cost_exp + cost_log;
const int64 cost =
max_seq_len * num_classes *
(cost_exp + Eigen::TensorOpCost::DivCost<float>()) +
(cost_exp + Eigen::TensorOpCost::DivCost<T>()) +
max_seq_len * 2 * (2 * num_classes + 1) *
(cost_log_sum_exp + cost_log) +
max_seq_len *
((2 * num_classes + 1) * cost_log_sum_exp +
num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<float>()));
num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<T>()));
Shard(workers->num_threads, workers->workers, batch_size, cost,
ComputeLossAndGradients);
} else {
@ -274,8 +280,9 @@ Status CTCLossCalculator::CalculateLoss(
return Status::OK();
}
template <typename T>
template <typename Vector>
Status CTCLossCalculator::PopulateLPrimes(
Status CTCLossCalculator<T>::PopulateLPrimes(
bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
int batch_size, int num_classes, const Vector& seq_len,
const LabelSequences& labels, size_t* max_u_prime,
@ -357,6 +364,174 @@ Status CTCLossCalculator::PopulateLPrimes(
return Status::OK();
}
// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
template <typename TT>
void CTCLossCalculator<TT>::CalculateForwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_alpha) const {
// Number of cols is the number of time steps = number of cols in target
// after the output delay.
log_alpha->setConstant(kLogZero<TT>::val);
int U = l_prime.size();
int T = log_alpha->cols();
CHECK_EQ(U, log_alpha->rows());
// Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
// Below, l_prime[1] == labels[0]
auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));
for (int t = 1; t < T; ++t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_alpha(u, t) continue to
// be kLogZero.
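// The loop bounds follow from the lattice topology: u advances by at
// most 2 per time step, so only u <= 2t + 1 is reachable by time t, and
// finishing all U positions requires u >= U - 2 * (T - t).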
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.9
// Add in the u, t - 1 term.
auto sum_log_alpha = kLogZero<TT>::val;
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
sum_log_alpha = log_alpha->coeff(u, t - 1);
}
// Add in the u - 1, t - 1 term.
if (u > 0) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
}
// Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
if (u > 1) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
}
}
// Multiply the summed alphas with the activation log probability.
log_alpha->coeffRef(u, t) =
log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
} // End (GravesTh) Eq 7.9.
}
}
// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
template <typename TT>
void CTCLossCalculator<TT>::CalculateBackwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_beta) const {
// Number of cols is the number of time steps = number of cols in target.
log_beta->setConstant(kLogZero<TT>::val);
int T = log_beta->cols();
int U = l_prime.size();
CHECK_EQ(U, log_beta->rows());
// Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
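// A valid alignment may end on either the last label (u = U - 2) or the
// trailing blank (u = U - 1).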
for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
for (int t = T - 1 - 1; t >= 0; --t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_beta(u, t) continue to
// be kLogZero.
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.15
// Add in the u, t + 1 term.
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u, t + 1) +
log(y(l_prime[u], output_delay_ + t + 1)));
}
// Add in the u + 1, t + 1 term.
if (u + 1 < U) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 1, t + 1) +
log(y(l_prime[u + 1], output_delay_ + t + 1)));
}
// Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
if (u + 2 < U) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
// Add in u + 2 term.
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 2, t + 1) +
log(y(l_prime[u + 2], output_delay_ + t + 1)));
}
} // End (GravesTh) Eq. 7.15
}
}
}
// Using (GravesTh) Eq 7.26 & 7.34.
template <typename TT>
void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
const Matrix& y,
const Matrix& log_alpha,
const Matrix& log_beta,
TT log_p_z_x, Matrix* dy) const {
// Only working with the leftmost part of dy for this batch element.
auto dy_b = dy->leftCols(y.cols());
// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero<TT>::val) {
LOG(WARNING) << "No valid path found.";
dy_b = y;
return;
}
int L = y.rows();
int T = y.cols();
int U = l_prime.size();
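// For each time step, prob_sum[l] below accumulates log_alpha + log_beta
// over every position u of l' that emits label l; this is the numerator
// of (GravesTh) Eq 7.28.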
for (int t = 0; t < T - output_delay_; ++t) {
Array prob_sum(L);
prob_sum.setConstant(kLogZero<TT>::val);
for (int u = 0; u < U; ++u) {
int l = l_prime[u];
prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
}
for (int l = 0; l < L; ++l) {
// Negative term in (GravesTh) Eq 7.28.
// Use the overloaded exp() rather than expf() so TT = double keeps
// full precision.
TT negative_term = exp(prob_sum[l] - log_p_z_x);
dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
}
}
}
template <typename TT>
void CTCLossCalculator<TT>::GetLPrimeIndices(const std::vector<int>& l,
std::vector<int>* l_prime) const {
// Assumption is that l_prime is empty.
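// Example: l = {a, b} yields l' = {blank, a, blank, b, blank}.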
l_prime->reserve(2 * l.size() + 1);
for (auto label : l) {
l_prime->push_back(blank_index_);
l_prime->push_back(label);
}
// Add final blank to l'.
l_prime->push_back(blank_index_);
}
} // namespace ctc
} // namespace tensorflow
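The templated helpers these functions rely on are introduced in ctc_loss_util.h, shown next. As an illustration only, and not part of the commit, the following minimal sketch mirrors kLogZero<T> and the templated LogSumExp so it compiles standalone and exercises the double-precision path; std::log1p stands in here for the header's log(1 + exp(...)) form.

#include <cmath>
#include <cstdio>
#include <limits>

// Mirrors the kLogZero<T> trait introduced in ctc_loss_util.h below.
template <typename T>
struct kLogZero {
  static constexpr T val = -std::numeric_limits<T>::infinity();
};

// Mirrors the templated LogSumExp; log1p is used purely for illustration.
template <typename T>
T LogSumExp(T log_prob_1, T log_prob_2) {
  if (log_prob_1 == kLogZero<T>::val) return log_prob_2;
  if (log_prob_2 == kLogZero<T>::val) return log_prob_1;
  // Keep the larger argument outside the exp so it cannot blow up.
  return (log_prob_1 > log_prob_2)
             ? log_prob_1 + std::log1p(std::exp(log_prob_2 - log_prob_1))
             : log_prob_2 + std::log1p(std::exp(log_prob_1 - log_prob_2));
}

int main() {
  // ln(0.5) combined with ln(0.25) in log space should give ln(0.75).
  double got = LogSumExp(std::log(0.5), std::log(0.25));
  std::printf("%.17g vs %.17g\n", got, std::log(0.75));
  return 0;
}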

View File

@ -23,18 +23,24 @@ limitations under the License.
namespace tensorflow {
namespace ctc {
const float kLogZero = -std::numeric_limits<float>::infinity();
template <typename T>
struct kLogZero {
static constexpr T val = -std::numeric_limits<T>::infinity();
};
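// Call sites now pick the precision explicitly, e.g. kLogZero<float>::val
// or kLogZero<double>::val.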
// Add logarithmic probabilities using:
// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
// The two inputs are assumed to be log probabilities.
// (GravesTh) Eq. 7.18
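// Worked example: LogSumExp(ln 0.5, ln 0.25)
//   = ln 0.5 + ln(1 + exp(ln 0.25 - ln 0.5)) = ln 0.5 + ln 1.5 = ln 0.75.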
inline float LogSumExp(float log_prob_1, float log_prob_2) {
template <typename T>
inline T LogSumExp(T log_prob_1, T log_prob_2) {
// Always have 'b' be the smaller number to avoid the exponential from
// blowing up.
if (log_prob_1 == kLogZero) {
if (log_prob_1 == kLogZero<T>::val) {
return log_prob_2;
} else if (log_prob_2 == kLogZero) {
} else if (log_prob_2 == kLogZero<T>::val) {
return log_prob_1;
} else {
return (log_prob_1 > log_prob_2)