Merge pull request #31164 from aprimostka:ctc-float64

PiperOrigin-RevId: 264272190
TensorFlower Gardener 2019-08-19 17:31:36 -07:00
commit fe5cf47131
11 changed files with 535 additions and 425 deletions

tensorflow/core/kernels/ctc_decoder_ops.cc

@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -33,11 +34,12 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
int* c) {
template <typename T>
inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
int* c) {
*c = 0;
CHECK_LT(0, m.dimension(1));
float p = m(r, 0);
auto p = m(r, 0);
for (int i = 1; i < m.dimension(1); ++i) {
if (m(r, i) > p) {
p = m(r, i);
@@ -170,6 +172,7 @@ class CTCDecodeHelper {
TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
template <typename T>
class CTCGreedyDecoderOp : public OpKernel {
public:
explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -189,7 +192,7 @@ class CTCGreedyDecoderOp : public OpKernel {
const TensorShape& inputs_shape = inputs->shape();
std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
const int64 max_time = inputs_shape.dim_size(0);
const int64 batch_size = inputs_shape.dim_size(1);
const int64 num_classes_raw = inputs_shape.dim_size(2);
@@ -198,14 +201,14 @@ class CTCGreedyDecoderOp : public OpKernel {
errors::InvalidArgument("num_classes cannot exceed max int"));
const int num_classes = static_cast<const int>(num_classes_raw);
auto inputs_t = inputs->tensor<float, 3>();
auto inputs_t = inputs->tensor<T, 3>();
for (std::size_t t = 0; t < max_time; ++t) {
input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
batch_size, num_classes);
}
auto seq_len_t = seq_len->vec<int32>();
auto log_prob_t = log_prob->matrix<float>();
auto log_prob_t = log_prob->matrix<T>();
log_prob_t.setZero();
@@ -221,7 +224,8 @@ class CTCGreedyDecoderOp : public OpKernel {
int prev_indices = -1;
for (int t = 0; t < seq_len_t(b); ++t) {
int max_class_indices;
log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
log_prob_t(b, 0) +=
-RowMax<T>(input_list_t[t], b, &max_class_indices);
if (max_class_indices != blank_index &&
!(merge_repeated_ && max_class_indices == prev_indices)) {
sequence.push_back(max_class_indices);
@@ -250,10 +254,18 @@ class CTCGreedyDecoderOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};
REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
CTCGreedyDecoderOp);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
CTCGreedyDecoderOp<T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
#undef REGISTER_CPU
// CTC beam search
template <typename T>
class CTCBeamSearchDecoderOp : public OpKernel {
public:
explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -275,9 +287,9 @@ class CTCBeamSearchDecoderOp : public OpKernel {
ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
&decoded_values, &decoded_shape));
auto inputs_t = inputs->tensor<float, 3>();
auto inputs_t = inputs->tensor<T, 3>();
auto seq_len_t = seq_len->vec<int32>();
auto log_prob_t = log_prob->matrix<float>();
auto log_prob_t = log_prob->matrix<T>();
const TensorShape& inputs_shape = inputs->shape();
@@ -291,21 +303,21 @@ class CTCBeamSearchDecoderOp : public OpKernel {
log_prob_t.setZero();
std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
for (std::size_t t = 0; t < max_time; ++t) {
input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
batch_size, num_classes);
}
ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
&beam_scorer_, 1 /* batch_size */,
merge_repeated_);
Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
auto input_chip_t = input_chip.flat<float>();
ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
&beam_scorer_, 1 /* batch_size */,
merge_repeated_);
Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
auto input_chip_t = input_chip.flat<T>();
std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
std::vector<float> log_probs;
std::vector<T> log_probs;
// Assumption: the blank index is num_classes - 1
for (int b = 0; b < batch_size; ++b) {
@@ -313,8 +325,8 @@ class CTCBeamSearchDecoderOp : public OpKernel {
best_paths_b.resize(decode_helper_.GetTopPaths());
for (int t = 0; t < seq_len_t(b); ++t) {
input_chip_t = input_list_t[t].chip(b, 0);
auto input_bi =
Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
input_chip_t.data(), num_classes);
beam_search.Step(input_bi);
}
OP_REQUIRES_OK(
@@ -335,13 +347,20 @@ class CTCBeamSearchDecoderOp : public OpKernel {
private:
CTCDecodeHelper decode_helper_;
ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
bool merge_repeated_;
int beam_width_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
};
REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
CTCBeamSearchDecoderOp);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
CTCBeamSearchDecoderOp<T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
#undef REGISTER_CPU
} // end namespace tensorflow
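Note: the hunks above template RowMax and the decoder kernels on the scalar type T and register CPU kernels for both float and double. Below is a minimal standalone sketch (not part of the commit; plain Eigen, illustrative names) of the same row-argmax logic, showing how the returned maximum follows T:

#include <Eigen/Dense>
#include <iostream>

// Illustrative stand-in for the templated RowMax<T> above, written against a
// plain Eigen matrix instead of TTypes<T>::UnalignedConstMatrix.
template <typename T>
T RowMaxSketch(const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m,
               int r, int* c) {
  *c = 0;
  T p = m(r, 0);
  for (int i = 1; i < m.cols(); ++i) {
    if (m(r, i) > p) {
      p = m(r, i);
      *c = i;
    }
  }
  return p;
}

int main() {
  Eigen::MatrixXd logits(1, 3);  // double-precision logits, now supported
  logits << 0.1, 0.7, 0.2;
  int max_class = -1;
  double max_logit = RowMaxSketch<double>(logits, /*r=*/0, &max_class);
  std::cout << "max " << max_logit << " at class " << max_class << "\n";
}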

tensorflow/core/kernels/ctc_loss_op.cc

@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -26,14 +27,13 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
class CTCLossOp : public OpKernel {
typedef Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor> >
typedef Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
InputMap;
typedef Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
OutputMap;
public:
@@ -110,7 +110,7 @@ class CTCLossOp : public OpKernel {
errors::InvalidArgument("label SparseTensor is not valid: ",
labels_sp_valid.error_message()));
ctc::CTCLossCalculator::LabelSequences labels_t(batch_size);
typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size);
for (const auto& g : labels_sp.group({0})) { // iterate by batch
const int64 batch_indices = g.group()[0];
OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
@@ -137,13 +137,13 @@ class CTCLossOp : public OpKernel {
Tensor* loss = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
auto loss_t = loss->vec<float>();
auto loss_t = loss->vec<T>();
Tensor* gradient;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("gradient", inputs_shape, &gradient));
auto gradient_t = gradient->tensor<float, 3>();
auto inputs_t = inputs->tensor<float, 3>();
auto gradient_t = gradient->tensor<T, 3>();
auto inputs_t = inputs->tensor<T, 3>();
std::vector<OutputMap> gradient_list_t;
std::vector<InputMap> input_list_t;
@@ -158,7 +158,7 @@ class CTCLossOp : public OpKernel {
gradient_t.setZero();
// Assumption: the blank index is num_classes - 1
ctc::CTCLossCalculator ctc_loss_calculator(num_classes - 1, 0);
ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0);
DeviceBase::CpuWorkerThreads workers =
*ctx->device()->tensorflow_cpu_worker_threads();
OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
@@ -173,9 +173,17 @@ class CTCLossOp : public OpKernel {
bool ctc_merge_repeated_;
bool ignore_longer_outputs_than_inputs_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>);
};
REGISTER_KERNEL_BUILDER(Name("CTCLoss").Device(DEVICE_CPU), CTCLossOp);
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
CTCLossOp<T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
#undef REGISTER_CPU
} // end namespace tensorflow
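Note: CTCLossOp<T> views the input and gradient buffers through the InputMap/OutputMap typedefs above. A standalone sketch (not part of the commit; plain Eigen, illustrative names) of how such a const row-major Eigen::Map exposes an existing buffer of T without copying:

#include <Eigen/Dense>
#include <iostream>
#include <vector>

// Illustrative analogue of the InputMap typedef in CTCLossOp<T>: a const,
// row-major, dynamically sized view over an existing buffer of T.
template <typename T>
using InputMapSketch = Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

int main() {
  const int batch_size = 2, num_classes = 3;
  // One timestep of logits, laid out row-major as the kernel assumes.
  std::vector<double> buffer = {0.1, 0.2, 0.7,
                                0.6, 0.3, 0.1};
  InputMapSketch<double> logits(buffer.data(), batch_size, num_classes);
  std::cout << logits.row(1).maxCoeff() << "\n";  // prints 0.6
}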

tensorflow/core/ops/ctc_ops.cc

@@ -25,15 +25,16 @@ using shape_inference::ShapeHandle;
// CTC is Connectionist Temporal Classification. See util/ctc/ for details.
REGISTER_OP("CTCLoss")
.Input("inputs: float")
.Input("inputs: T")
.Input("labels_indices: int64")
.Input("labels_values: int32")
.Input("sequence_length: int32")
.Attr("preprocess_collapse_repeated: bool = false")
.Attr("ctc_merge_repeated: bool = true")
.Attr("ignore_longer_outputs_than_inputs: bool = false")
.Output("loss: float")
.Output("gradient: float")
.Output("loss: T")
.Output("gradient: T")
.Attr("T: {float, double} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle labels_indices;
@@ -62,13 +63,14 @@ REGISTER_OP("CTCLoss")
});
REGISTER_OP("CTCGreedyDecoder")
.Input("inputs: float")
.Input("inputs: T")
.Input("sequence_length: int32")
.Attr("merge_repeated: bool = false")
.Output("decoded_indices: int64")
.Output("decoded_values: int64")
.Output("decoded_shape: int64")
.Output("log_probability: float")
.Output("log_probability: T")
.Attr("T: {float, double} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle sequence_length;
@@ -90,7 +92,7 @@ REGISTER_OP("CTCGreedyDecoder")
});
REGISTER_OP("CTCBeamSearchDecoder")
.Input("inputs: float")
.Input("inputs: T")
.Input("sequence_length: int32")
.Attr("beam_width: int >= 1")
.Attr("top_paths: int >= 1")
@@ -98,7 +100,8 @@ REGISTER_OP("CTCBeamSearchDecoder")
.Output("decoded_indices: top_paths * int64")
.Output("decoded_values: top_paths * int64")
.Output("decoded_shape: top_paths * int64")
.Output("log_probability: float")
.Output("log_probability: T")
.Attr("T: {float, double} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle sequence_length;

tensorflow/core/util/ctc/ctc_beam_entry.h

@@ -41,30 +41,34 @@ namespace ctc_beam_search {
struct EmptyBeamState {};
template <typename T>
struct BeamProbability {
BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
BeamProbability()
: total(kLogZero<T>::val),
blank(kLogZero<T>::val),
label(kLogZero<T>::val) {}
void Reset() {
total = kLogZero;
blank = kLogZero;
label = kLogZero;
total = kLogZero<T>::val;
blank = kLogZero<T>::val;
label = kLogZero<T>::val;
}
float total;
float blank;
float label;
T total;
T blank;
T label;
};
template <class CTCBeamState>
template <class T, class CTCBeamState>
class BeamRoot;
template <class CTCBeamState = EmptyBeamState>
template <class T, class CTCBeamState = EmptyBeamState>
struct BeamEntry {
// BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
BeamEntry<CTCBeamState>* p, int l);
inline bool Active() const { return newp.total != kLogZero; }
friend BeamEntry<T, CTCBeamState>* BeamRoot<T, CTCBeamState>::AddEntry(
BeamEntry<T, CTCBeamState>* p, int l);
inline bool Active() const { return newp.total != kLogZero<T>::val; }
// Return the child at the given index, or construct a new one in-place if
// none was found.
BeamEntry& GetChild(int ind) {
BeamEntry<T, CTCBeamState>& GetChild(int ind) {
auto entry = children.emplace(ind, nullptr);
auto& child_entry = entry.first->second;
// If this is a new child, populate the BeamEntry<CTCBeamState>*.
@@ -76,7 +80,7 @@ struct BeamEntry {
std::vector<int> LabelSeq(bool merge_repeated) const {
std::vector<int> labels;
int prev_label = -1;
const BeamEntry* c = this;
const BeamEntry<T, CTCBeamState>* c = this;
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
if (!merge_repeated || c->label != prev_label) {
labels.push_back(c->label);
@@ -88,12 +92,12 @@ struct BeamEntry {
return labels;
}
BeamEntry<CTCBeamState>* parent;
BeamEntry<T, CTCBeamState>* parent;
int label;
// All instances of child BeamEntry are owned by *beam_root.
gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
BeamProbability oldp;
BeamProbability newp;
gtl::FlatMap<int, BeamEntry<T, CTCBeamState>*> children;
BeamProbability<T> oldp;
BeamProbability<T> newp;
CTCBeamState state;
private:
@@ -102,40 +106,42 @@ struct BeamEntry {
// otherwise parent will become invalid.
// This private constructor is only called through the factory method
// BeamRoot<CTCBeamState>::AddEntry().
BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
BeamEntry(BeamEntry* p, int l, BeamRoot<T, CTCBeamState>* beam_root)
: parent(p), label(l), beam_root(beam_root) {}
BeamRoot<CTCBeamState>* beam_root;
BeamRoot<T, CTCBeamState>* beam_root;
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
};
// This class owns all instances of BeamEntry. This is used to avoid recursive
// destructor call during destruction.
template <class CTCBeamState = EmptyBeamState>
template <class T, class CTCBeamState = EmptyBeamState>
class BeamRoot {
public:
BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
BeamRoot(BeamEntry<T, CTCBeamState>* p, int l) {
root_entry_ = AddEntry(p, l);
}
BeamRoot(const BeamRoot&) = delete;
BeamRoot& operator=(const BeamRoot&) = delete;
BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
BeamEntry<T, CTCBeamState>* AddEntry(BeamEntry<T, CTCBeamState>* p, int l) {
auto* new_entry = new BeamEntry<T, CTCBeamState>(p, l, this);
beam_entries_.emplace_back(new_entry);
return new_entry;
}
BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
BeamEntry<T, CTCBeamState>* RootEntry() const { return root_entry_; }
private:
BeamEntry<CTCBeamState>* root_entry_ = nullptr;
std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
BeamEntry<T, CTCBeamState>* root_entry_ = nullptr;
std::vector<std::unique_ptr<BeamEntry<T, CTCBeamState>>> beam_entries_;
};
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
template <class CTCBeamState = EmptyBeamState>
template <class T, class CTCBeamState = EmptyBeamState>
class BeamComparer {
public:
virtual ~BeamComparer() {}
virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
const BeamEntry<CTCBeamState>* b) const {
virtual bool inline operator()(const BeamEntry<T, CTCBeamState>* a,
const BeamEntry<T, CTCBeamState>* b) const {
return a->newp.total > b->newp.total;
}
};
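Note: BeamProbability<T> above reads kLogZero<T>::val, whose definition lives in ctc_loss_util.h (one of the 11 changed files, not shown in this section). A plausible minimal definition consistent with its use here would be negative infinity in the requested precision; this is an assumption, not the commit's code:

#include <limits>

// Assumed shape of the templated log-zero constant used by BeamProbability<T>;
// the authoritative definition is in ctc_loss_util.h, not reproduced in this
// section.
template <typename T>
struct kLogZero {
  static constexpr T val = -std::numeric_limits<T>::infinity();
};

// Usage mirroring BeamProbability<T>::Reset():
//   total = kLogZero<float>::val;   // -inf in float
//   blank = kLogZero<double>::val;  // -inf in double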

tensorflow/core/util/ctc/ctc_beam_scorer.h

@@ -34,7 +34,7 @@ namespace ctc {
// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex
// scoring is required. Its main purpose is to provide a thin layer for
// integrating language model scoring easily.
template <typename CTCBeamState>
template <typename T, typename CTCBeamState>
class BaseBeamScorer {
public:
virtual ~BaseBeamScorer() {}
@@ -56,8 +56,8 @@ class BaseBeamScorer {
//
// The score returned should be a log-probability. In the simplest case, as
// there's no state expansion logic, the expansion score is zero.
virtual float GetStateExpansionScore(const CTCBeamState& state,
float previous_score) const {
virtual T GetStateExpansionScore(const CTCBeamState& state,
T previous_score) const {
return previous_score;
}
// GetStateEndExpansionScore should be an inexpensive method to retrieve the
@@ -65,8 +65,8 @@ class BaseBeamScorer {
// multiplied (log-addition) with the final probability of the beam.
//
// The score returned should be a log-probability.
virtual float GetStateEndExpansionScore(const CTCBeamState& state) const {
return 0;
virtual T GetStateEndExpansionScore(const CTCBeamState& state) const {
return T(0);
}
};
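Note: with BaseBeamScorer templated on T, custom scorers override the same hooks with T-valued scores. A hedged sketch of a trivial scorer against the new signatures, assuming a TensorFlow build so the changed headers are available (class name and penalty value are illustrative):

#include "tensorflow/core/util/ctc/ctc_beam_entry.h"
#include "tensorflow/core/util/ctc/ctc_beam_scorer.h"

// Sketch only: a scorer applying a fixed log-space penalty per expansion,
// overriding the now-T-valued hooks shown above.
template <typename T>
class ConstantPenaltyScorer
    : public tensorflow::ctc::BaseBeamScorer<
          T, tensorflow::ctc::ctc_beam_search::EmptyBeamState> {
 public:
  using State = tensorflow::ctc::ctc_beam_search::EmptyBeamState;
  T GetStateExpansionScore(const State& state,
                           T previous_score) const override {
    return previous_score - T(0.1);  // illustrative constant penalty
  }
  T GetStateEndExpansionScore(const State& state) const override {
    return T(0);
  }
};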

tensorflow/core/util/ctc/ctc_beam_search.h

@@ -38,10 +38,10 @@ limitations under the License.
namespace tensorflow {
namespace ctc {
template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
template <typename T, typename CTCBeamState = ctc_beam_search::EmptyBeamState,
typename CTCBeamComparer =
ctc_beam_search::BeamComparer<CTCBeamState>>
class CTCBeamSearchDecoder : public CTCDecoder {
ctc_beam_search::BeamComparer<T, CTCBeamState>>
class CTCBeamSearchDecoder : public CTCDecoder<T> {
// Beam Search
//
// Example (GravesTh Fig. 7.5):
@@ -73,12 +73,12 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// starts at 0). This special case can be calculated as:
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
// but we calculate it recursively for speed purposes.
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability BeamProbability;
typedef ctc_beam_search::BeamEntry<T, CTCBeamState> BeamEntry;
typedef ctc_beam_search::BeamRoot<T, CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability<T> BeamProbability;
public:
typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
typedef BaseBeamScorer<T, CTCBeamState> DefaultBeamScorer;
// The beam search decoder is constructed specifying the beam_width (number of
// candidates to keep at each decoding timestep) and a beam scorer (used for
@@ -87,9 +87,9 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
// standard beam search.
CTCBeamSearchDecoder(int num_classes, int beam_width,
BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
bool merge_repeated = false)
: CTCDecoder(num_classes, batch_size, merge_repeated),
BaseBeamScorer<T, CTCBeamState>* scorer,
int batch_size = 1, bool merge_repeated = false)
: CTCDecoder<T>(num_classes, batch_size, merge_repeated),
beam_width_(beam_width),
leaves_(beam_width),
beam_scorer_(CHECK_NOTNULL(scorer)) {
@@ -99,27 +99,28 @@ class CTCBeamSearchDecoder : public CTCDecoder {
~CTCBeamSearchDecoder() override {}
// Run the hibernating beam search algorithm on the given input.
Status Decode(const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores) override;
Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
const std::vector<typename CTCDecoder<T>::Input>& input,
std::vector<typename CTCDecoder<T>::Output>* output,
typename CTCDecoder<T>::ScoreOutput* scores) override;
// Calculate the next step of the beam search and update the internal state.
template <typename Vector>
void Step(const Vector& log_input_t);
template <typename Vector>
float GetTopK(const int K, const Vector& input,
std::vector<float>* top_k_logits,
std::vector<int>* top_k_indices);
T GetTopK(const int K, const Vector& input, std::vector<T>* top_k_logits,
std::vector<int>* top_k_indices);
// Retrieve the beam scorer instance used during decoding.
BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
BaseBeamScorer<T, CTCBeamState>* GetBeamScorer() const {
return beam_scorer_;
}
// Set label selection parameters for faster decoding.
// See comments for label_selection_size_ and label_selection_margin_.
void SetLabelSelectionParameters(int label_selection_size,
float label_selection_margin) {
T label_selection_margin) {
label_selection_size_ = label_selection_size;
label_selection_margin_ = label_selection_margin;
}
@@ -129,7 +130,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// Extract the top n paths at current time step
Status TopPaths(int n, std::vector<std::vector<int>>* paths,
std::vector<float>* log_probs, bool merge_repeated) const;
std::vector<T>* log_probs, bool merge_repeated) const;
private:
int beam_width_;
@@ -145,37 +146,38 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// Default is to do no label selection.
// For more detail: https://research.google.com/pubs/pub44823.html
int label_selection_size_ = 0; // zero means unlimited
float label_selection_margin_ = -1; // -1 means unlimited.
T label_selection_margin_ = -1; // -1 means unlimited.
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
std::unique_ptr<BeamRoot> beam_root_;
BaseBeamScorer<CTCBeamState>* beam_scorer_;
BaseBeamScorer<T, CTCBeamState>* beam_scorer_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
};
template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Decode(
const typename CTCDecoder<T>::SequenceLength& seq_len,
const std::vector<typename CTCDecoder<T>::Input>& input,
std::vector<typename CTCDecoder<T>::Output>* output,
typename CTCDecoder<T>::ScoreOutput* scores) {
// Storage for top paths.
std::vector<std::vector<int>> beams;
std::vector<float> beam_log_probabilities;
std::vector<T> beam_log_probabilities;
int top_n = output->size();
if (std::any_of(output->begin(), output->end(),
[this](const CTCDecoder::Output& output) -> bool {
[this](const typename CTCDecoder<T>::Output& output) -> bool {
return output.size() < this->batch_size_;
})) {
return errors::InvalidArgument(
"output needs to be of size at least (top_n, batch_size).");
}
if (scores->rows() < batch_size_ || scores->cols() < top_n) {
if (scores->rows() < this->batch_size_ || scores->cols() < top_n) {
return errors::InvalidArgument(
"scores needs to be of size at least (batch_size, top_n).");
}
for (int b = 0; b < batch_size_; ++b) {
for (int b = 0; b < this->batch_size_; ++b) {
int seq_len_b = seq_len[b];
Reset();
@@ -196,7 +198,7 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
}
Status status =
TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_);
if (!status.ok()) {
return status;
}
@@ -213,20 +215,20 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
return Status::OK();
}
template <typename CTCBeamState, typename CTCBeamComparer>
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
const int K, const Vector& input, std::vector<float>* top_k_logits,
T CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::GetTopK(
const int K, const Vector& input, std::vector<T>* top_k_logits,
std::vector<int>* top_k_indices) {
// Find Top K choices, complexity nk in worst case. The array input is read
// just once.
CHECK_EQ(num_classes_, input.size());
CHECK_EQ(this->num_classes_, input.size());
top_k_logits->clear();
top_k_indices->clear();
top_k_logits->resize(K, -INFINITY);
top_k_indices->resize(K, -1);
for (int j = 0; j < num_classes_ - 1; ++j) {
const float logit = input(j);
for (int j = 0; j < this->num_classes_ - 1; ++j) {
const T logit = input(j);
if (logit > (*top_k_logits)[K - 1]) {
int k = K - 1;
while (k > 0 && logit > (*top_k_logits)[k - 1]) {
@@ -239,43 +241,43 @@ float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
}
}
// Return max value which is in 0th index or blank character logit
return std::max((*top_k_logits)[0], input(num_classes_ - 1));
return std::max((*top_k_logits)[0], input(this->num_classes_ - 1));
}
template <typename CTCBeamState, typename CTCBeamComparer>
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
template <typename Vector>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
const Vector& raw_input) {
std::vector<float> top_k_logits;
std::vector<T> top_k_logits;
std::vector<int> top_k_indices;
const bool top_k =
(label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
// Number of character classes to consider in each step.
const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
const int max_classes =
top_k ? label_selection_size_ : (this->num_classes_ - 1);
// Get max coefficient and remove it from raw_input later.
float max_coeff;
T max_coeff;
if (top_k) {
max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
&top_k_indices);
} else {
max_coeff = raw_input.maxCoeff();
}
// Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
float logsumexp = 0.0;
T logsumexp = T(0.0);
for (int j = 0; j < raw_input.size(); ++j) {
logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
}
logsumexp = Eigen::numext::log(logsumexp);
// Final normalization offset to get correct log probabilities.
float norm_offset = max_coeff + logsumexp;
T norm_offset = max_coeff + logsumexp;
const float label_selection_input_min =
const T label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
: -std::numeric_limits<T>::infinity();
// Extract the beams sorted in decreasing new probability
CHECK_EQ(num_classes_, raw_input.size());
CHECK_EQ(this->num_classes_, raw_input.size());
std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
leaves_.Reset();
@@ -294,8 +296,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// else:
// Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
// + P(l=ab @ t=5))
float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
: b->parent->oldp.total;
T previous = (b->label == b->parent->label) ? b->parent->oldp.blank
: b->parent->oldp.total;
b->newp.label =
LogSumExp(b->newp.label,
beam_scorer_->GetStateExpansionScore(b->state, previous));
@@ -304,7 +306,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
b->newp.blank = b->oldp.total + raw_input(this->blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -325,7 +327,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// isn't full, or the lowest probability entry in the beam has a
// lower probability than the leaf.
auto is_candidate = [this](const BeamProbability& prob) {
return (prob.total > kLogZero &&
return (prob.total > kLogZero<T>::val &&
(leaves_.size() < beam_width_ ||
prob.total > leaves_.peek_bottom()->newp.total));
};
@@ -336,7 +338,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
for (int ind = 0; ind < max_classes; ind++) {
const int label = top_k ? top_k_indices[ind] : ind;
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
const T logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
// We may compare logits instead of log probabilities,
@@ -347,13 +349,13 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
BeamEntry& c = b->GetChild(label);
if (!c.Active()) {
// Pblank(l=abcd @ t=6) = 0
c.newp.blank = kLogZero;
c.newp.blank = kLogZero<T>::val;
// If new child label is identical to beam label:
// Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
// Otherwise:
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
T previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
@@ -379,15 +381,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} // for (BeamEntry* b...
}
template <typename CTCBeamState, typename CTCBeamComparer>
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Reset() {
leaves_.Reset();
// This beam root, and all of its children, will be in memory until
// the next reset.
beam_root_.reset(new BeamRoot(nullptr, -1));
beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
beam_root_->RootEntry()->newp.total = T(0.0); // ln(1)
beam_root_->RootEntry()->newp.blank = T(0.0); // ln(1)
// Add the root as the initial leaf.
leaves_.push(beam_root_->RootEntry());
@@ -396,9 +398,9 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
}
template <typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
template <typename T, typename CTCBeamState, typename CTCBeamComparer>
Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::TopPaths(
int n, std::vector<std::vector<int>>* paths, std::vector<T>* log_probs,
bool merge_repeated) const {
CHECK_NOTNULL(paths)->clear();
CHECK_NOTNULL(log_probs)->clear();
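Note: Step() above converts raw logits to log-probabilities with the max trick, norm_offset = max_coeff + log(sum_j exp(logit_j - max_coeff)). A standalone sketch of that normalization (not part of the commit; illustrative names), valid for float or double T:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Log-softmax via the max trick, as used to compute norm_offset in Step().
template <typename T>
std::vector<T> LogSoftmaxSketch(const std::vector<T>& logits) {
  const T max_coeff = *std::max_element(logits.begin(), logits.end());
  T logsumexp = T(0);
  for (const T v : logits) logsumexp += std::exp(v - max_coeff);
  logsumexp = std::log(logsumexp);
  const T norm_offset = max_coeff + logsumexp;
  std::vector<T> log_probs;
  log_probs.reserve(logits.size());
  for (const T v : logits) log_probs.push_back(v - norm_offset);
  return log_probs;
}

int main() {
  // Matches the single-timestep input used in the finite-scores test below.
  for (const double lp :
       LogSoftmaxSketch<double>({0.4, 0.3, 0.0, 0.0, 0.0, 0.5}))
    std::cout << lp << " ";
  std::cout << "\n";
}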

tensorflow/core/util/ctc/ctc_beam_search_test.cc

@@ -19,19 +19,20 @@ limitations under the License.
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include <cmath>
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace {
typedef std::vector<std::vector<std::vector<float>>> TestData;
using tensorflow::ctc::CTCBeamSearchDecoder;
using tensorflow::ctc::CTCDecoder;
template <class T>
using TestData = std::vector<std::vector<std::vector<T>>>;
// The HistoryBeamState is used to keep track of the current candidate and
// caches the expansion score (needed by the scorer).
template <class T>
struct HistoryBeamState {
float score;
T score;
std::vector<int> labels;
};
@@ -40,48 +41,48 @@ struct HistoryBeamState {
// a prefix of a dictionary word it gets a low probability at each step.
//
// The dictionary itself is hard-coded a static const variable of the class.
template <class T, class BeamState>
class DictionaryBeamScorer
: public tensorflow::ctc::BaseBeamScorer<HistoryBeamState> {
: public tensorflow::ctc::BaseBeamScorer<T, BeamState> {
public:
void InitializeState(HistoryBeamState* root) const override {
root->score = 0;
}
DictionaryBeamScorer()
: tensorflow::ctc::BaseBeamScorer<T, BeamState>(),
dictionary_({{3}, {3, 1}}) {}
void ExpandState(const HistoryBeamState& from_state, int from_label,
HistoryBeamState* to_state, int to_label) const override {
void InitializeState(BeamState* root) const override { root->score = 0; }
void ExpandState(const BeamState& from_state, int from_label,
BeamState* to_state, int to_label) const override {
// Keep track of the current complete candidate by storing the labels along
// the expansion path in the beam state.
to_state->labels.push_back(to_label);
SetStateScoreAccordingToDict(to_state);
}
void ExpandStateEnd(HistoryBeamState* state) const override {
void ExpandStateEnd(BeamState* state) const override {
SetStateScoreAccordingToDict(state);
}
float GetStateExpansionScore(const HistoryBeamState& state,
float previous_score) const override {
T GetStateExpansionScore(const BeamState& state,
T previous_score) const override {
return previous_score + state.score;
}
float GetStateEndExpansionScore(
const HistoryBeamState& state) const override {
T GetStateEndExpansionScore(const BeamState& state) const override {
return state.score;
}
// Simple dictionary used when scoring the beams to check if they are prefixes
// of dictionary words (see SetStateScoreAccordingToDict below).
static const std::vector<std::vector<int>> dictionary_;
const std::vector<std::vector<int>> dictionary_;
private:
void SetStateScoreAccordingToDict(HistoryBeamState* state) const;
void SetStateScoreAccordingToDict(BeamState* state) const;
};
const std::vector<std::vector<int>> DictionaryBeamScorer::dictionary_ = {
{3}, {3, 1}};
void DictionaryBeamScorer::SetStateScoreAccordingToDict(
HistoryBeamState* state) const {
template <class T, class BeamState>
void DictionaryBeamScorer<T, BeamState>::SetStateScoreAccordingToDict(
BeamState* state) const {
// Check if the beam can still be a dictionary word (e.g. prefix of one).
const std::vector<int>& candidate = state->labels;
for (int w = 0; w < dictionary_.size(); ++w) {
@@ -92,32 +93,35 @@ void DictionaryBeamScorer::SetStateScoreAccordingToDict(
}
if (std::equal(word.begin(), word.begin() + candidate.size(),
candidate.begin())) {
state->score = std::log(1.0);
state->score = std::log(T(1.0));
return;
}
}
// At this point, the candidate certainly can't be in the dictionary.
state->score = std::log(0.01);
state->score = std::log(T(0.01));
}
TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
template <class T>
void ctc_beam_search_decoding_with_and_without_dictionary() {
const int batch_size = 1;
const int timesteps = 5;
const int top_paths = 3;
const int num_classes = 6;
// Plain decoder using hibernating beam search algorithm.
CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
CTCBeamSearchDecoder<> decoder(num_classes, 10 * top_paths, &default_scorer);
typename tensorflow::ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer
default_scorer;
tensorflow::ctc::CTCBeamSearchDecoder<T> decoder(num_classes, 10 * top_paths,
&default_scorer);
// Dictionary decoder, allowing only two dictionary words : {3}, {3, 1}.
DictionaryBeamScorer dictionary_scorer;
CTCBeamSearchDecoder<HistoryBeamState> dictionary_decoder(
num_classes, top_paths, &dictionary_scorer);
DictionaryBeamScorer<T, HistoryBeamState<T>> dictionary_scorer;
tensorflow::ctc::CTCBeamSearchDecoder<T, HistoryBeamState<T>>
dictionary_decoder(num_classes, top_paths, &dictionary_scorer);
// Raw data containers (arrays of floats, ints, etc.).
// Raw data containers (arrays of floats64, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
float input_data_mat[timesteps][batch_size][num_classes] = {
T input_data_mat[timesteps][batch_size][num_classes] = {
{{0, 0.6, 0, 0.4, 0, 0}},
{{0, 0.5, 0, 0.5, 0, 0}},
{{0, 0.4, 0, 0.6, 0, 0}},
@@ -134,34 +138,40 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
}
// Plain output, without any additional scoring.
std::vector<CTCDecoder::Output> expected_output = {
{{1, 3}, {1, 3, 1}, {3, 1, 3}},
};
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> expected_output =
{
{{1, 3}, {1, 3, 1}, {3, 1, 3}},
};
// Dictionary outputs: preference for dictionary candidates. The
// second-candidate is there, despite it not being a dictionary word, due to
// stronger probability in the input to the decoder.
std::vector<CTCDecoder::Output> expected_dict_output = {
{{3}, {1, 3}, {3, 1}},
};
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
expected_dict_output = {
{{3}, {1, 3}, {3, 1}},
};
// Convert data containers to the format accepted by the decoder, simply
// mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
std::vector<
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>>
inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
}
// Prepare containers for output and scores.
std::vector<CTCDecoder::Output> outputs(top_paths);
for (CTCDecoder::Output& output : outputs) {
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> outputs(
top_paths);
for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : outputs) {
output.resize(batch_size);
}
float score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
T score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> scores(
&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
@@ -169,8 +179,9 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
}
// Prepare dictionary outputs.
std::vector<CTCDecoder::Output> dict_outputs(top_paths);
for (CTCDecoder::Output& output : dict_outputs) {
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> dict_outputs(
top_paths);
for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : dict_outputs) {
output.resize(batch_size);
}
EXPECT_TRUE(
@@ -180,38 +191,45 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
}
}
TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
template <class T>
void ctc_beam_search_decoding_all_beam_elements_have_finite_scores() {
const int batch_size = 1;
const int timesteps = 1;
const int top_paths = 3;
const int num_classes = 6;
// Plain decoder using hibernating beam search algorithm.
CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
CTCBeamSearchDecoder<> decoder(num_classes, top_paths, &default_scorer);
typename tensorflow::ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer
default_scorer;
tensorflow::ctc::CTCBeamSearchDecoder<T> decoder(num_classes, top_paths,
&default_scorer);
// Raw data containers (arrays of floats, ints, etc.).
// Raw data containers (arrays of floats64, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
float input_data_mat[timesteps][batch_size][num_classes] = {
T input_data_mat[timesteps][batch_size][num_classes] = {
{{0.4, 0.3, 0, 0, 0, 0.5}}};
// Convert data containers to the format accepted by the decoder, simply
// mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
std::vector<
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>>
inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
}
// Prepare containers for output and scores.
std::vector<CTCDecoder::Output> outputs(top_paths);
for (CTCDecoder::Output& output : outputs) {
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> outputs(
top_paths);
for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : outputs) {
output.resize(batch_size);
}
float score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
T score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> scores(
&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
// Make sure all scores are finite.
@@ -226,8 +244,9 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
typedef int LabelState; // The state is simply the final label.
template <class T>
class RapidlyDroppingLabelScorer
: public tensorflow::ctc::BaseBeamScorer<LabelState> {
: public tensorflow::ctc::BaseBeamScorer<T, LabelState> {
public:
void InitializeState(LabelState* root) const override {}
@@ -238,75 +257,84 @@ class RapidlyDroppingLabelScorer
void ExpandStateEnd(LabelState* state) const override {}
float GetStateExpansionScore(const LabelState& state,
float previous_score) const override {
T GetStateExpansionScore(const LabelState& state,
T previous_score) const override {
// Drop off rapidly for later labels.
const float kRapidly = 100;
const T kRapidly = 100;
return previous_score - kRapidly * state;
}
float GetStateEndExpansionScore(const LabelState& state) const override {
return 0;
T GetStateEndExpansionScore(const LabelState& state) const override {
return T(0);
}
};
TEST(CtcBeamSearch, LabelSelection) {
template <class T>
void ctc_beam_search_label_selection() {
const int batch_size = 1;
const int timesteps = 3;
const int top_paths = 5;
const int num_classes = 6;
// Decoder which drops off log-probabilities for labels 0 >> 1 >> 2 >> 3.
RapidlyDroppingLabelScorer scorer;
CTCBeamSearchDecoder<LabelState> decoder(num_classes, top_paths, &scorer);
RapidlyDroppingLabelScorer<T> scorer;
tensorflow::ctc::CTCBeamSearchDecoder<T, LabelState> decoder(
num_classes, top_paths, &scorer);
// Raw data containers (arrays of floats, ints, etc.).
// Raw data containers (arrays of floats64, ints, etc.).
int sequence_lengths[batch_size] = {timesteps};
// Log probabilities, slightly preferring later labels, this decision
// should be overridden by the scorer which strongly prefers earlier labels.
// The last one is empty label, and for simplicity we give it an extremely
// high cost to ignore it. We also use the first label to break up the
// repeated label sequence.
float input_data_mat[timesteps][batch_size][num_classes] = {
T input_data_mat[timesteps][batch_size][num_classes] = {
{{-1e6, 1, 2, 3, 4, -1e6}},
{{1e6, 0, 0, 0, 0, -1e6}}, // force label 0 to break up repeated
{{-1e6, 1.1, 2.2, 3.3, 4.4, -1e6}},
};
// Expected output without label selection
std::vector<CTCDecoder::Output> expected_default_output = {
{{1, 0, 1}, {1, 0, 2}, {2, 0, 1}, {1, 0, 3}, {2, 0, 2}},
};
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
expected_default_output = {
{{1, 0, 1}, {1, 0, 2}, {2, 0, 1}, {1, 0, 3}, {2, 0, 2}},
};
// Expected output with label selection limiting to 2 items
// this is suboptimal because only labels 3 and 4 were allowed to be seen.
std::vector<CTCDecoder::Output> expected_output_size2 = {
{{3, 0, 3}, {3, 0, 4}, {4, 0, 3}, {4, 0, 4}, {3}},
};
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
expected_output_size2 = {
{{3, 0, 3}, {3, 0, 4}, {4, 0, 3}, {4, 0, 4}, {3}},
};
// Expected output with label width of 2.0. This would permit three labels at
// the first timestep, but only two at the last.
std::vector<CTCDecoder::Output> expected_output_width2 = {
{{2, 0, 3}, {2, 0, 4}, {3, 0, 3}, {3, 0, 4}, {4, 0, 3}},
};
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output>
expected_output_width2 = {
{{2, 0, 3}, {2, 0, 4}, {3, 0, 3}, {3, 0, 4}, {4, 0, 3}},
};
// Convert data containers to the format accepted by the decoder, simply
// mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
// using Eigen::Map.
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
std::vector<
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>>
inputs;
inputs.reserve(timesteps);
for (int t = 0; t < timesteps; ++t) {
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
}
// Prepare containers for output and scores.
std::vector<CTCDecoder::Output> outputs(top_paths);
for (CTCDecoder::Output& output : outputs) {
std::vector<typename tensorflow::ctc::CTCDecoder<T>::Output> outputs(
top_paths);
for (typename tensorflow::ctc::CTCDecoder<T>::Output& output : outputs) {
output.resize(batch_size);
}
float score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
T score[batch_size][top_paths] = {{0.0}};
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> scores(
&score[0][0], batch_size, top_paths);
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
@@ -314,14 +342,14 @@ TEST(CtcBeamSearch, LabelSelection) {
}
// Try label selection size 2
decoder.SetLabelSelectionParameters(2, -1);
decoder.SetLabelSelectionParameters(2, T(-1));
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
}
// Try label selection width 2.0
decoder.SetLabelSelectionParameters(0, 2.0);
decoder.SetLabelSelectionParameters(0, T(2.0));
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_output_width2[0][path]);
@@ -329,18 +357,42 @@ TEST(CtcBeamSearch, LabelSelection) {
// Try both size 2 and width 2.0: the former is more constraining, so
// it's equivalent to that.
decoder.SetLabelSelectionParameters(2, 2.0);
decoder.SetLabelSelectionParameters(2, T(2.0));
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
}
// Size 4 and width > 3.3 are equivalent to no label selection
decoder.SetLabelSelectionParameters(4, 3.3001);
decoder.SetLabelSelectionParameters(4, T(3.3001));
EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
for (int path = 0; path < top_paths; ++path) {
EXPECT_EQ(outputs[path][0], expected_default_output[0][path]);
}
}
TEST(CtcBeamSearch, FloatDecodingWithAndWithoutDictionary) {
ctc_beam_search_decoding_with_and_without_dictionary<float>();
}
TEST(CtcBeamSearch, DoubleDecodingWithAndWithoutDictionary) {
ctc_beam_search_decoding_with_and_without_dictionary<double>();
}
TEST(CtcBeamSearch, FloatAllBeamElementsHaveFiniteScores) {
ctc_beam_search_decoding_all_beam_elements_have_finite_scores<float>();
}
TEST(CtcBeamSearch, DoubleAllBeamElementsHaveFiniteScores) {
ctc_beam_search_decoding_all_beam_elements_have_finite_scores<double>();
}
TEST(CtcBeamSearch, FloatLabelSelection) {
ctc_beam_search_label_selection<float>();
}
TEST(CtcBeamSearch, DoubleLabelSelection) {
ctc_beam_search_label_selection<double>();
}
} // namespace
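Note: the tests above now route each case through a templated helper and instantiate it twice by hand as Float*/Double* TEST pairs. googletest's typed tests could express the same float/double fan-out; the following is an alternative sketch, not what the commit does:

#include "tensorflow/core/platform/test.h"

// Alternative sketch using googletest typed tests (TYPED_TEST_CASE in older
// googletest releases); the commit instead writes Float*/Double* TEST pairs.
template <typename T>
class CtcBeamSearchTyped : public ::testing::Test {};

using ScalarTypes = ::testing::Types<float, double>;
TYPED_TEST_SUITE(CtcBeamSearchTyped, ScalarTypes);

TYPED_TEST(CtcBeamSearchTyped, DecodingWithAndWithoutDictionary) {
  ctc_beam_search_decoding_with_and_without_dictionary<TypeParam>();
}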

tensorflow/core/util/ctc/ctc_decoder.h

@@ -33,12 +33,15 @@ namespace ctc {
// The two types of decoding available are:
// - greedy path, through the CTCGreedyDecoder
// - beam search, through the CTCBeamSearchDecoder
template <class T>
class CTCDecoder {
public:
typedef Eigen::Map<const Eigen::ArrayXi> SequenceLength;
typedef Eigen::Map<const Eigen::MatrixXf> Input;
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
Input;
typedef std::vector<std::vector<int>> Output;
typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ScoreOutput;
CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
: num_classes_(num_classes),
@@ -69,25 +72,27 @@ class CTCDecoder {
// CTCGreedyDecoder is an implementation of the simple best path decoding
// algorithm, selecting at each timestep the most likely class at each timestep.
class CTCGreedyDecoder : public CTCDecoder {
template <class T>
class CTCGreedyDecoder : public CTCDecoder<T> {
public:
typedef CTCDecoder<T> Decoder;
CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
: CTCDecoder(num_classes, batch_size, merge_repeated) {}
: CTCDecoder<T>(num_classes, batch_size, merge_repeated) {}
Status Decode(const CTCDecoder::SequenceLength& seq_len,
const std::vector<CTCDecoder::Input>& input,
std::vector<CTCDecoder::Output>* output,
CTCDecoder::ScoreOutput* scores) override {
if (output->empty() || (*output)[0].size() < batch_size_) {
Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
const std::vector<typename CTCDecoder<T>::Input>& input,
std::vector<typename CTCDecoder<T>::Output>* output,
typename CTCDecoder<T>::ScoreOutput* scores) override {
if (output->empty() || (*output)[0].size() < Decoder::batch_size_) {
return errors::InvalidArgument(
"output needs to be of size at least (1, batch_size).");
}
if (scores->rows() < batch_size_ || scores->cols() == 0) {
if (scores->rows() < Decoder::batch_size_ || scores->cols() == 0) {
return errors::InvalidArgument(
"scores needs to be of size at least (batch_size, 1).");
}
// For each batch entry, identify the transitions
for (int b = 0; b < batch_size_; ++b) {
for (int b = 0; b < Decoder::batch_size_; ++b) {
int seq_len_b = seq_len[b];
// Only writing to beam 0
std::vector<int>& output_b = (*output)[0][b];
@@ -98,8 +103,8 @@ class CTCGreedyDecoder : public CTCDecoder {
auto row = input[t].row(b);
int max_class_ix;
(*scores)(b, 0) += -row.maxCoeff(&max_class_ix);
if (max_class_ix != blank_index_ &&
!(merge_repeated_ && max_class_ix == prev_class_ix)) {
if (max_class_ix != Decoder::blank_index_ &&
!(Decoder::merge_repeated_ && max_class_ix == prev_class_ix)) {
output_b.push_back(max_class_ix);
}
prev_class_ix = max_class_ix;
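Note: the greedy decode loop above (per-timestep argmax, skip the blank index, optionally merge repeats) is independent of the tensor plumbing. A standalone sketch of the same logic (not part of the commit; illustrative names):

#include <iostream>
#include <vector>

// Greedy CTC decode: argmax per timestep, drop blanks, optionally collapse
// repeats.  T can be float or double.
template <typename T>
std::vector<int> GreedyDecodeSketch(const std::vector<std::vector<T>>& logits,
                                    int blank_index, bool merge_repeated) {
  std::vector<int> output;
  int prev_class_ix = -1;
  for (const auto& row : logits) {
    int max_class_ix = 0;
    for (int c = 1; c < static_cast<int>(row.size()); ++c)
      if (row[c] > row[max_class_ix]) max_class_ix = c;
    if (max_class_ix != blank_index &&
        !(merge_repeated && max_class_ix == prev_class_ix)) {
      output.push_back(max_class_ix);
    }
    prev_class_ix = max_class_ix;
  }
  return output;
}

int main() {
  // 3 timesteps, 3 classes, blank = 2.
  std::vector<std::vector<double>> logits = {
      {0.1, 0.8, 0.1}, {0.2, 0.7, 0.1}, {0.1, 0.1, 0.8}};
  for (int label : GreedyDecodeSketch(logits, /*blank_index=*/2,
                                      /*merge_repeated=*/true))
    std::cout << label << " ";  // prints "1"
  std::cout << "\n";
}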

tensorflow/core/util/ctc/ctc_loss_calculator.cc

@@ -14,173 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include <cmath>
namespace tensorflow {
namespace ctc {
// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
void CTCLossCalculator::CalculateForwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_alpha) const {
// Number of cols is the number of time steps = number of cols in target
// after the output delay.
log_alpha->setConstant(kLogZero);
int U = l_prime.size();
int T = log_alpha->cols();
CHECK_EQ(U, log_alpha->rows());
// Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
log_alpha->coeffRef(0, 0) = std::log(y(blank_index_, output_delay_));
// Below, l_prime[1] == labels[0]
auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
log_alpha->coeffRef(1, 0) = std::log(y(label_0, output_delay_));
for (int t = 1; t < T; ++t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_alpha(u, t) continue to
// be kLogZero.
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.9
// Add in the u, t - 1 term.
float sum_log_alpha = kLogZero;
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
sum_log_alpha = log_alpha->coeff(u, t - 1);
}
// Add in the u - 1, t - 1 term.
if (u > 0) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
}
// Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
if (u > 1) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
}
}
// Multiply the summed alphas with the activation log probability.
log_alpha->coeffRef(u, t) =
std::log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
} // End (GravesTh) Eq 7.9.
}
}
// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
void CTCLossCalculator::CalculateBackwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_beta) const {
// Number of cols is the number of time steps = number of cols in target.
// Matrix log_beta =
// Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
// kLogZero);
log_beta->setConstant(kLogZero);
int T = log_beta->cols();
int U = l_prime.size();
CHECK_EQ(U, log_beta->rows());
// Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
for (int t = T - 1 - 1; t >= 0; --t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_beta(u, t) continue to
// be kLogZero.
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.15
// Add in the u, t + 1 term.
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u, t + 1) +
std::log(y(l_prime[u], output_delay_ + t + 1)));
}
// Add in the u + 1, t + 1 term.
if (u + 1 < U) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 1, t + 1) +
std::log(y(l_prime[u + 1], output_delay_ + t + 1)));
}
// Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
if (u + 2 < U) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
// Add in u + 2 term.
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 2, t + 1) +
std::log(y(l_prime[u + 2], output_delay_ + t + 1)));
}
} // End (GravesTh) Eq. 7.15
}
}
}
// Using (GravesTh) Eq 7.26 & 7.34.
void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
const Matrix& y,
const Matrix& log_alpha,
const Matrix& log_beta,
float log_p_z_x, Matrix* dy) const {
// Only working with the leftmost part of dy for this batch element.
auto dy_b = dy->leftCols(y.cols());
// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero) {
LOG(WARNING) << "No valid path found.";
dy_b = y;
return;
}
int L = y.rows();
int T = y.cols();
int U = l_prime.size();
for (int t = 0; t < T - output_delay_; ++t) {
Array prob_sum(L);
prob_sum.setConstant(kLogZero);
for (int u = 0; u < U; ++u) {
int l = l_prime[u];
prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
}
for (int l = 0; l < L; ++l) {
// Negative term in (GravesTh) Eq 7.28.
float negative_term = expf(prob_sum[l] - log_p_z_x);
dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
}
}
}
void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
std::vector<int>* l_prime) const {
// Assumption is that l_prime is empty.
l_prime->reserve(2 * l.size() + 1);
for (auto label : l) {
l_prime->push_back(blank_index_);
l_prime->push_back(label);
}
// Add final blank to l'.
l_prime->push_back(blank_index_);
}
} // namespace ctc
namespace ctc {} // namespace ctc
} // namespace tensorflow

tensorflow/core/util/ctc/ctc_loss_calculator.h

@@ -30,6 +30,7 @@ limitations under the License.
namespace tensorflow {
namespace ctc {
template <class T>
class CTCLossCalculator {
// Connectionist Temporal Classification Loss
//
@@ -50,10 +51,14 @@ class CTCLossCalculator {
// Neural Networks" (PhD Thesis), Technische Universität München.
public:
typedef std::vector<std::vector<int>> LabelSequences;
typedef Eigen::MatrixXf Matrix;
typedef Eigen::ArrayXf Array;
typedef Eigen::Map<const Eigen::MatrixXf> InputMap;
typedef Eigen::Map<Eigen::MatrixXf> OutputMap;
using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
// typedef Eigen::MatrixXd Matrix;
using Array = Eigen::Array<T, Eigen::Dynamic, 1>;
// typedef Eigen::ArrayXd Array;
using InputMap = Eigen::Map<const Matrix>;
// typedef Eigen::Map<const Eigen::MatrixXd> InputMap;
using OutputMap = Eigen::Map<Matrix>;
// typedef Eigen::Map<Eigen::MatrixXd> OutputMap;
CTCLossCalculator(int blank_index, int output_delay)
: blank_index_(blank_index), output_delay_(output_delay) {}
@@ -79,7 +84,7 @@ class CTCLossCalculator {
void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
const Matrix& log_alpha, const Matrix& log_beta,
float log_p_z_x, Matrix* dy) const;
T log_p_z_x, Matrix* dy) const;
void GetLPrimeIndices(const std::vector<int>& l,
std::vector<int>* l_prime) const;
@@ -103,9 +108,10 @@ class CTCLossCalculator {
const int output_delay_;
};
template <class T>
template <typename VectorIn, typename VectorOut, typename MatrixIn,
typename MatrixOut>
Status CTCLossCalculator::CalculateLoss(
Status CTCLossCalculator<T>::CalculateLoss(
const VectorIn& seq_len, const LabelSequences& labels,
const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
@@ -205,11 +211,11 @@ Status CTCLossCalculator::CalculateLoss(
// Convert label from DistBelief
// y, prob are in num_classes x seq_len(b)
// Output activations.
Eigen::ArrayXf y_b_col;
Array y_b_col;
for (int t = 0; t < seq_len(b); t++) {
// Calculate the softmax of y_b. Use double precision
// Calculate the softmax of y_b. Use original precision
// arithmetic for the sum.
float max_coeff = inputs[t].row(b).maxCoeff();
T max_coeff = inputs[t].row(b).maxCoeff();
y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
y_b.col(t) = y_b_col / y_b_col.sum();
}
@@ -222,7 +228,7 @@ Status CTCLossCalculator::CalculateLoss(
// The loss is computed as the log(p(z|x)) between the target and
// prediction. Do lazy evaluation of log_prob here.
float log_p_z_x = kLogZero;
T log_p_z_x = kLogZero<T>::val;
for (int u = 0; u < l_prime.size(); ++u) {
// (GravesTh) Eq 7.26, sum over all paths for t = 0.
log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
@@ -253,19 +259,19 @@ Status CTCLossCalculator::CalculateLoss(
// fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) +
// grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add)).
const int64 cost_exp = Eigen::internal::functor_traits<
Eigen::internal::scalar_exp_op<float>>::Cost;
Eigen::internal::scalar_exp_op<T>>::Cost;
const int64 cost_log = Eigen::internal::functor_traits<
Eigen::internal::scalar_log_op<float>>::Cost;
Eigen::internal::scalar_log_op<T>>::Cost;
const int64 cost_log_sum_exp =
Eigen::TensorOpCost::AddCost<float>() + cost_exp + cost_log;
Eigen::TensorOpCost::AddCost<T>() + cost_exp + cost_log;
const int64 cost =
max_seq_len * num_classes *
(cost_exp + Eigen::TensorOpCost::DivCost<float>()) +
(cost_exp + Eigen::TensorOpCost::DivCost<T>()) +
max_seq_len * 2 * (2 * num_classes + 1) *
(cost_log_sum_exp + cost_log) +
max_seq_len *
((2 * num_classes + 1) * cost_log_sum_exp +
num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<float>()));
num_classes * (cost_exp + Eigen::TensorOpCost::AddCost<T>()));
Shard(workers->num_threads, workers->workers, batch_size, cost,
ComputeLossAndGradients);
} else {
@@ -274,8 +280,9 @@ Status CTCLossCalculator::CalculateLoss(
return Status::OK();
}
template <class T>
template <typename Vector>
Status CTCLossCalculator::PopulateLPrimes(
Status CTCLossCalculator<T>::PopulateLPrimes(
bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
int batch_size, int num_classes, const Vector& seq_len,
const LabelSequences& labels, size_t* max_u_prime,
@@ -357,6 +364,173 @@ Status CTCLossCalculator::PopulateLPrimes(
return Status::OK();
}
// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
// Starting with t = 0 instead of t = 1 used in the text.
// Based on Kanishka's CTC.
template <typename TT>
void CTCLossCalculator<TT>::CalculateForwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_alpha) const {
// Number of cols is the number of time steps = number of cols in target
// after the output delay.
log_alpha->setConstant(kLogZero<TT>::val);
int U = l_prime.size();
int T = log_alpha->cols();
CHECK_EQ(U, log_alpha->rows());
// Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
// Below, l_prime[1] == labels[0]
auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));
for (int t = 1; t < T; ++t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_alpha(u, t) continue to
// be kLogZero.
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.9
// Add in the u, t - 1 term.
auto sum_log_alpha = kLogZero<TT>::val;
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
sum_log_alpha = log_alpha->coeff(u, t - 1);
}
// Add in the u - 1, t - 1 term.
if (u > 0) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
}
// Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
if (u > 1) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
sum_log_alpha =
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
}
}
// Multiply the summed alphas with the activation log probability.
log_alpha->coeffRef(u, t) =
log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
} // End (GravesTh) Eq 7.9.
}
}
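In probability space, the inner loop above implements (one transcription of (GravesTh) Eq 7.9 as coded, with b the blank index, merge = ctc_merge_repeated, and terms with negative indices dropped):

  \alpha(u, t) = y_{l'_u}(t)\Big([\text{merge} \lor l'_u = b]\,\alpha(u, t-1) + \alpha(u-1, t-1) + [l'_u \ne b \land \lnot(\text{merge} \land l'_u = l'_{u-2})]\,\alpha(u-2, t-1)\Big),

where y is indexed at output_delay_ + t as in the code; the implementation stores \log\alpha and combines the bracketed terms with LogSumExp.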
// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
template <class TT>
void CTCLossCalculator<TT>::CalculateBackwardVariables(
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
Matrix* log_beta) const {
// Number of cols is the number of time steps = number of cols in target.
// Matrix log_beta =
// Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
// kLogZero);
log_beta->setConstant(kLogZero<TT>::val);
int T = log_beta->cols();
int U = l_prime.size();
CHECK_EQ(U, log_beta->rows());
// Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
for (int t = T - 1 - 1; t >= 0; --t) {
// If there is not enough time to output the remaining labels or
// some labels have been skipped, then let log_beta(u, t) continue to
// be kLogZero.
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
++u) {
// Begin (GravesTh) Eq 7.15
// Add in the u, t + 1 term.
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u, t + 1) +
log(y(l_prime[u], output_delay_ + t + 1)));
}
// Add in the u + 1, t + 1 term.
if (u + 1 < U) {
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 1, t + 1) +
log(y(l_prime[u + 1], output_delay_ + t + 1)));
}
// Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
if (u + 2 < U) {
const bool matching_labels_merge =
ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
// Add in u + 2 term.
log_beta->coeffRef(u, t) =
LogSumExp(log_beta->coeff(u, t),
log_beta->coeff(u + 2, t + 1) +
log(y(l_prime[u + 2], output_delay_ + t + 1)));
}
} // End (GravesTh) Eq. 7.15
}
}
}
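Similarly, in probability space the backward pass implements ((GravesTh) Eq 7.15 as coded, with terms dropped when u+1 or u+2 exceeds U-1):

  \beta(u, t) = [\text{merge} \lor l'_u = b]\,\beta(u, t+1)\,y_{l'_u}(t+1) + \beta(u+1, t+1)\,y_{l'_{u+1}}(t+1) + [l'_u \ne b \land \lnot(\text{merge} \land l'_u = l'_{u+2})]\,\beta(u+2, t+1)\,y_{l'_{u+2}}(t+1),

again with y indexed at output_delay_ + t + 1 and everything accumulated in log space via LogSumExp.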
// Using (GravesTh) Eq 7.26 & 7.34.
template <typename TT>
void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
const Matrix& y,
const Matrix& log_alpha,
const Matrix& log_beta,
TT log_p_z_x, Matrix* dy) const {
// Only working with the leftmost part of dy for this batch element.
auto dy_b = dy->leftCols(y.cols());
// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero<TT>::val) {
LOG(WARNING) << "No valid path found.";
dy_b = y;
return;
}
int L = y.rows();
int T = y.cols();
int U = l_prime.size();
for (int t = 0; t < T - output_delay_; ++t) {
Array prob_sum(L);
prob_sum.setConstant(kLogZero<TT>::val);
for (int u = 0; u < U; ++u) {
int l = l_prime[u];
prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
}
for (int l = 0; l < L; ++l) {
// Negative term in (GravesTh) Eq 7.28.
auto negative_term = std::exp(prob_sum[l] - log_p_z_x);  // std::exp keeps full precision when T is double.
dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
}
}
}
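The dy computed above corresponds to (GravesTh) Eq 7.28 as coded here: for logits a_k^t with softmax output y_k^t,

  \frac{\partial \mathcal{L}}{\partial a_k^t} = y_k^t - \frac{1}{p(z|x)} \sum_{u : l'_u = k} \alpha(u, t)\,\beta(u, t),

and the second (negative) term is exactly the exponential of prob_sum[k] - log_p_z_x evaluated in the loop above.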
template <class TT>
void CTCLossCalculator<TT>::GetLPrimeIndices(const std::vector<int>& l,
std::vector<int>* l_prime) const {
// Assumption is that l_prime is empty.
l_prime->reserve(2 * l.size() + 1);
for (auto label : l) {
l_prime->push_back(blank_index_);
l_prime->push_back(label);
}
// Add final blank to l'.
l_prime->push_back(blank_index_);
}
} // namespace ctc
} // namespace tensorflow

View File

@@ -23,18 +23,23 @@ limitations under the License.
namespace tensorflow {
namespace ctc {
const float kLogZero = -std::numeric_limits<float>::infinity();
template <class T>
struct kLogZero {
static constexpr T val = -std::numeric_limits<T>::infinity(); // NOLINT
};
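A tiny self-contained check (illustrative only; the struct is renamed to avoid clashing with kLogZero above) of the property the templated constant relies on: -infinity is representable and compares exactly for both supported precisions, so the == guards in LogSumExp below remain valid.

#include <limits>

template <class T>
struct LogZeroSketch {
  static constexpr T val = -std::numeric_limits<T>::infinity();
};

static_assert(LogZeroSketch<float>::val == LogZeroSketch<float>::val,
              "-inf compares equal to itself");
static_assert(LogZeroSketch<double>::val <
                  std::numeric_limits<double>::lowest(),
              "-inf is below every finite double");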
// Add logarithmic probabilities using:
// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
// The two inputs are assumed to be log probabilities.
// (GravesTh) Eq. 7.18
inline float LogSumExp(float log_prob_1, float log_prob_2) {
template <typename T>
inline T LogSumExp(T log_prob_1, T log_prob_2) {
// const T kLogZero = -std::numeric_limits<T>::infinity();
// Always have 'b' be the smaller number to avoid the exponential from
// blowing up.
if (log_prob_1 == kLogZero) {
if (log_prob_1 == kLogZero<T>::val) {
return log_prob_2;
} else if (log_prob_2 == kLogZero) {
} else if (log_prob_2 == kLogZero<T>::val) {
return log_prob_1;
} else {
return (log_prob_1 > log_prob_2)