Add CTC (Connectionist Temporal Classification) Ops to TF contrib.
This includes:
* ctc_loss
* ctc_greedy_decoder
* ctc_beam_search_decoder

Change: 115683564
This commit is contained in:
parent fdfbd3af7a
commit 8509f88ade
tensorflow
@@ -12,6 +12,7 @@ py_library(
    srcs = glob(["**/*.py"]),
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/ctc:ctc_py",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/contrib/linear_optimizer:sdca_ops_py",
        "//tensorflow/contrib/testing:testing_py",
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function

# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import ctc
from tensorflow.contrib import layers
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import testing
tensorflow/contrib/ctc/BUILD (new file, 53 lines)
@@ -0,0 +1,53 @@
# Description:
#   contains parts of TensorFlow that are experimental or unstable and which are not supported.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package(default_visibility = ["//tensorflow:__subpackages__"])

load("//tensorflow:tensorflow.bzl", "cuda_py_tests")

py_library(
    name = "ctc_py",
    srcs = [
        "__init__.py",
        "ctc_ops.py",
    ],
    srcs_version = "PY2AND3",
)

cuda_py_tests(
    name = "ctc_decoder_ops_test",
    srcs = ["ctc_decoder_ops_test.py"],
    additional_deps = [
        ":ctc_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

py_test(
    name = "ctc_loss_op_test",
    srcs = ["ctc_loss_op_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":ctc_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/contrib/ctc/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Ops for CTC (Connectionist Temporal Classification).

@@ctc_loss
@@ctc_greedy_decoder
@@ctc_beam_search_decoder
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.ctc.ctc_ops import *
tensorflow/contrib/ctc/ctc_decoder_ops_test.py (new file, 200 lines)
@@ -0,0 +1,200 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.ctc_ops.ctc_decoder_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
import tensorflow as tf

# itertools.izip_longest is Python 2 only; these sources are built with
# srcs_version = "PY2AND3", so fall back to zip_longest on Python 3.
try:
  zip_longest = itertools.izip_longest  # Python 2
except AttributeError:
  zip_longest = itertools.zip_longest  # Python 3


def grouper(iterable, n, fillvalue=None):
  """Collect data into fixed-length chunks or blocks."""
  # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
  args = [iter(iterable)] * n
  return zip_longest(fillvalue=fillvalue, *args)


def flatten(list_of_lists):
  """Flatten one level of nesting."""
  return itertools.chain.from_iterable(list_of_lists)
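A quick sketch of how these two helpers compose (an editor's illustration, not part of the commit): the test below flattens each SparseTensor's (indices, values, shape) triple so that a single sess.run fetches everything at once, then regroups the fetched results back into triples.

# Hypothetical data; flatten triples for sess.run, then regroup afterwards.
triples = [("ix0", "val0", "shape0"), ("ix1", "val1", "shape1")]
flat = list(flatten(triples))
# flat == ["ix0", "val0", "shape0", "ix1", "val1", "shape1"]
regrouped = list(grouper(flat, 3))
# regrouped == [("ix0", "val0", "shape0"), ("ix1", "val1", "shape1")]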
class CTCGreedyDecoderTest(tf.test.TestCase):

  def _testCTCDecoder(self, decoder, inputs, seq_lens, log_prob_truth,
                      decode_truth, expected_err_re=None, **decoder_args):
    inputs_t = [tf.convert_to_tensor(x) for x in inputs]
    # convert inputs_t into a [max_time x batch_size x depth] tensor
    # from a len time python list of [batch_size x depth] tensors
    inputs_t = tf.pack(inputs_t)

    with self.test_session(use_gpu=False) as sess:
      decoded_list, log_probability = decoder(
          inputs_t,
          sequence_length=seq_lens, **decoder_args)
      decoded_unwrapped = list(flatten([
          (st.indices, st.values, st.shape) for st in decoded_list]))

      if expected_err_re is None:
        outputs = sess.run(
            decoded_unwrapped + [log_probability])

        # Group outputs into (ix, vals, shape) tuples
        output_sparse_tensors = list(grouper(outputs[:-1], 3))

        output_log_probability = outputs[-1]

        # Check that the number of decoded outputs (top_paths) matches
        self.assertEqual(len(output_sparse_tensors), len(decode_truth))

        # For each SparseTensor tuple, compare (ix, vals, shape)
        for out_st, truth_st, tf_st in zip(
            output_sparse_tensors, decode_truth, decoded_list):
          self.assertAllEqual(out_st[0], truth_st[0])  # ix
          self.assertAllEqual(out_st[1], truth_st[1])  # vals
          self.assertAllEqual(out_st[2], truth_st[2])  # shape
          # Compare the shapes of the components with the truth. The
          # `None` elements are not known statically.
          self.assertEqual([None, truth_st[0].shape[1]],
                           tf_st.indices.get_shape().as_list())
          self.assertEqual([None], tf_st.values.get_shape().as_list())
          self.assertShapeEqual(truth_st[2], tf_st.shape)

        # Make sure decoded probabilities match
        self.assertAllClose(output_log_probability, log_prob_truth, atol=1e-6)
      else:
        with self.assertRaisesOpError(expected_err_re):
          sess.run(decoded_unwrapped + [log_probability])

  def testCTCGreedyDecoder(self):
    """Test two batch entries - best path decoder."""
    max_time_steps = 6
    # depth == 4

    seq_len_0 = 4
    input_prob_matrix_0 = np.asarray(
        [[1.0, 0.0, 0.0, 0.0],  # t=0
         [0.0, 0.0, 0.4, 0.6],  # t=1
         [0.0, 0.0, 0.4, 0.6],  # t=2
         [0.0, 0.9, 0.1, 0.0],  # t=3
         [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
         [0.0, 0.0, 0.0, 0.0]],  # t=5 (ignored)
        dtype=np.float32)
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)

    seq_len_1 = 5
    # dimensions are time x depth
    input_prob_matrix_1 = np.asarray(
        [[0.1, 0.9, 0.0, 0.0],  # t=0
         [0.0, 0.9, 0.1, 0.0],  # t=1
         [0.0, 0.0, 0.1, 0.9],  # t=2
         [0.0, 0.9, 0.1, 0.1],  # t=3
         [0.9, 0.1, 0.0, 0.0],  # t=4
         [0.0, 0.0, 0.0, 0.0]],  # t=5 (ignored)
        dtype=np.float32)
    input_log_prob_matrix_1 = np.log(input_prob_matrix_1)

    # len max_time_steps array of batch_size x depth matrices
    inputs = [np.vstack([input_log_prob_matrix_0[t, :],
                         input_log_prob_matrix_1[t, :]])
              for t in range(max_time_steps)]

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([seq_len_0, seq_len_1], dtype=np.int32)

    # batch_size length vector of negative log probabilities
    log_prob_truth = np.array([
        np.sum(-np.log([1.0, 0.6, 0.6, 0.9])),
        np.sum(-np.log([0.9, 0.9, 0.9, 0.9, 0.9]))
    ], np.float32)[:, np.newaxis]

    # decode_truth: one SparseTensor (ix, vals, shape)
    decode_truth = [
        (np.array([[0, 0],  # batch 0, 2 outputs
                   [0, 1],
                   [1, 0],  # batch 1, 3 outputs
                   [1, 1],
                   [1, 2]], dtype=np.int64),
         np.array([0, 1,  # batch 0
                   1, 1, 0],  # batch 1
                  dtype=np.int64),
         # shape is batch x max_decoded_length
         np.array([2, 3], dtype=np.int64)),
    ]

    self._testCTCDecoder(
        tf.contrib.ctc.ctc_greedy_decoder,
        inputs, seq_lens, log_prob_truth, decode_truth)
  def testCTCDecoderBeamSearch(self):
    """Test one batch, two beams - hibernating beam search."""
    # max_time_steps == 7 (5 real steps + 2 padding steps below)
    depth = 6

    seq_len_0 = 5
    input_prob_matrix_0 = np.asarray(
        [[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
         [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
         [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
         [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
         [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
         # Random entry added in at time=5
         [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]],
        dtype=np.float32)
    # Add arbitrary offset - this is fine
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0) + 2.0

    # len max_time_steps array of batch_size x depth matrices
    inputs = ([input_log_prob_matrix_0[t, :][np.newaxis, :]
               for t in range(seq_len_0)]  # Pad to max_time_steps = 7
              + 2 * [np.zeros((1, depth), dtype=np.float32)])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([seq_len_0], dtype=np.int32)

    # batch_size length vector of negative log probabilities
    log_prob_truth = np.array([
        0.584855,  # output beam 0
        0.389139  # output beam 1
    ], np.float32)[np.newaxis, :]

    # decode_truth: two SparseTensors, (ix, values, shape)
    decode_truth = [
        # beam 0, batch 0, two outputs decoded
        (np.array([[0, 0], [0, 1]], dtype=np.int64),
         np.array([1, 0], dtype=np.int64),
         np.array([1, 2], dtype=np.int64)),
        # beam 1, batch 0, three outputs decoded
        (np.array([[0, 0], [0, 1], [0, 2]], dtype=np.int64),
         np.array([0, 1, 0], dtype=np.int64),
         np.array([1, 3], dtype=np.int64)),
    ]

    self._testCTCDecoder(
        tf.contrib.ctc.ctc_beam_search_decoder,
        inputs, seq_lens, log_prob_truth,
        decode_truth,
        beam_width=2,
        top_paths=2)


if __name__ == "__main__":
  tf.test.main()
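The greedy test above exercises exactly the best-path rule the kernel implements: per-frame argmax, optional merging of consecutive repeats, then dropping of blanks. A minimal NumPy sketch of that rule (an editor's illustration, assuming as the kernel does that the blank is the last class; not code from this commit):

import numpy as np

def greedy_decode(log_probs, seq_len, merge_repeated=True):
  """Best-path decode one batch entry; log_probs is [time x depth]."""
  blank = log_probs.shape[1] - 1  # assumption: blank index is depth - 1
  decoded, prev = [], -1
  neg_sum_logit = 0.0
  for t in range(seq_len):
    best = int(np.argmax(log_probs[t]))
    neg_sum_logit += -log_probs[t, best]
    if best != blank and not (merge_repeated and best == prev):
      decoded.append(best)
    prev = best
  return decoded, neg_sum_logit

# For batch entry 1 of the test above this yields [1, 1, 0] and
# sum(-log(0.9)) over 5 steps, matching log_prob_truth.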
tensorflow/contrib/ctc/ctc_loss_op_test.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.ctc_ops.ctc_loss_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def SimpleSparseTensorFrom(x):
  """Create a very simple SparseTensor with dimensions (batch, time).

  Args:
    x: a list of lists of type int

  Returns:
    A rank-2 `SparseTensor` built from the indices (x_ix), values (x_val)
    and dense shape derived from `x`.
  """
  x_ix = []
  x_val = []
  for batch_i, batch in enumerate(x):
    for time, val in enumerate(batch):
      x_ix.append([batch_i, time])
      x_val.append(val)
  x_shape = [len(x), np.asarray(x_ix).max(0)[1] + 1]
  x_ix = tf.constant(x_ix, tf.int64)
  x_val = tf.constant(x_val, tf.int32)
  x_shape = tf.constant(x_shape, tf.int64)

  return tf.SparseTensor(x_ix, x_val, x_shape)
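To make the helper concrete, a small worked example (editor's illustration): for a two-entry batch with labels of unequal length, the dense shape is [batch, longest time index + 1].

# SimpleSparseTensorFrom([[0, 1], [2]]) produces a SparseTensor with
#   indices = [[0, 0], [0, 1], [1, 0]]   # (batch, time) pairs
#   values  = [0, 1, 2]
#   shape   = [2, 2]                     # [len(x), max time index + 1]
labels = SimpleSparseTensorFrom([[0, 1], [2]])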
class CTCLossTest(tf.test.TestCase):

  def _testCTCLoss(self, inputs, seq_lens, labels,
                   loss_truth, grad_truth, expected_err_re=None):
    self.assertEqual(len(inputs), len(grad_truth))

    inputs_t = tf.constant(inputs)

    with self.test_session(use_gpu=False) as sess:
      loss = tf.contrib.ctc.ctc_loss(inputs=inputs_t,
                                     labels=labels,
                                     sequence_length=seq_lens)
      grad = tf.gradients(loss, [inputs_t])[0]

      self.assertShapeEqual(loss_truth, loss)
      self.assertShapeEqual(grad_truth, grad)

      if expected_err_re is None:
        (tf_loss, tf_grad) = sess.run([loss, grad])
        self.assertAllClose(tf_loss, loss_truth, atol=1e-6)
        self.assertAllClose(tf_grad, grad_truth, atol=1e-6)
      else:
        with self.assertRaisesOpError(expected_err_re):
          sess.run([loss, grad])

  def testBasic(self):
    """Test two batch entries."""
    # Input and ground truth from Alex Graves' implementation.
    #
    #### Batch entry 0 #####
    # targets: 0 1 2 1 0
    # outputs:
    # 0 0.633766 0.221185 0.0917319 0.0129757 0.0142857 0.0260553
    # 1 0.111121 0.588392 0.278779 0.0055756 0.00569609 0.010436
    # 2 0.0357786 0.633813 0.321418 0.00249248 0.00272882 0.0037688
    # 3 0.0663296 0.643849 0.280111 0.00283995 0.0035545 0.00331533
    # 4 0.458235 0.396634 0.123377 0.00648837 0.00903441 0.00623107
    # alpha:
    # 0 -3.64753 -0.456075 -inf -inf -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -0.986437 -inf -inf -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -inf -2.12145 -inf -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -inf -inf -2.56174 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -inf -inf -3.34211 -inf
    # beta:
    # 0 -inf -2.88604 -inf -inf -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -2.35568 -inf -inf -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -inf -1.22066 -inf -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -inf -inf -0.780373 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -inf -inf 0 0
    # prob: -3.34211
    # outputDerivs:
    # 0 -0.366234 0.221185 0.0917319 0.0129757 0.0142857 0.0260553
    # 1 0.111121 -0.411608 0.278779 0.0055756 0.00569609 0.010436
    # 2 0.0357786 0.633813 -0.678582 0.00249248 0.00272882 0.0037688
    # 3 0.0663296 -0.356151 0.280111 0.00283995 0.0035545 0.00331533
    # 4 -0.541765 0.396634 0.123377 0.00648837 0.00903441 0.00623107
    #
    #### Batch entry 1 #####
    #
    # targets: 0 1 1 0
    # outputs:
    # 0 0.30176 0.28562 0.0831517 0.0862751 0.0816851 0.161508
    # 1 0.24082 0.397533 0.0557226 0.0546814 0.0557528 0.19549
    # 2 0.230246 0.450868 0.0389607 0.038309 0.0391602 0.202456
    # 3 0.280884 0.429522 0.0326593 0.0339046 0.0326856 0.190345
    # 4 0.423286 0.315517 0.0338439 0.0393744 0.0339315 0.154046
    # alpha:
    # 0 -1.8232 -1.19812 -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -2.19315 -2.83037 -2.1206 -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -2.03268 -3.71783 -inf -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -4.56292 -inf -inf -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf -5.42262 -inf
    # beta:
    # 0 -inf -4.2245 -inf -inf -inf -inf -inf -inf -inf
    # 1 -inf -inf -inf -3.30202 -inf -inf -inf -inf -inf
    # 2 -inf -inf -inf -inf -1.70479 -0.856738 -inf -inf -inf
    # 3 -inf -inf -inf -inf -inf -0.859706 -0.859706 -0.549337 -inf
    # 4 -inf -inf -inf -inf -inf -inf -inf 0 0
    # prob: -5.42262
    # outputDerivs:
    # 0 -0.69824 0.28562 0.0831517 0.0862751 0.0816851 0.161508
    # 1 0.24082 -0.602467 0.0557226 0.0546814 0.0557528 0.19549
    # 2 0.230246 0.450868 0.0389607 0.038309 0.0391602 -0.797544
    # 3 0.280884 -0.570478 0.0326593 0.0339046 0.0326856 0.190345
    # 4 -0.576714 0.315517 0.0338439 0.0393744 0.0339315 0.154046

    # max_time_steps == 7
    depth = 6

    # seq_len_0 == 5
    targets_0 = [0, 1, 2, 1, 0]
    loss_log_prob_0 = -3.34211
    # dimensions are time x depth
    input_prob_matrix_0 = np.asarray(
        [[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
         [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
         [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
         [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
         [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
        dtype=np.float32)
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)
    gradient_log_prob_0 = np.asarray(
        [[-0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
         [0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436],
         [0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688],
         [0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533],
         [-0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
        dtype=np.float32)

    # seq_len_1 == 5
    targets_1 = [0, 1, 1, 0]
    loss_log_prob_1 = -5.42262
    # dimensions are time x depth
    input_prob_matrix_1 = np.asarray(
        [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
         [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
         [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
         [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
         [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
        dtype=np.float32)
    input_log_prob_matrix_1 = np.log(input_prob_matrix_1)
    gradient_log_prob_1 = np.asarray(
        [[-0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
         [0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549],
         [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544],
         [0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345],
         [-0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]],
        dtype=np.float32)

    # len max_time_steps array of 2 x depth matrices
    inputs = [np.vstack([input_log_prob_matrix_0[t, :],
                         input_log_prob_matrix_1[t, :]])
              for t in range(5)] + 2 * [np.nan * np.ones((2, depth), np.float32)]

    # convert inputs into a [max_time x batch_size x depth] Tensor
    inputs = np.asarray(inputs)

    # len batch_size array of label vectors
    labels = SimpleSparseTensorFrom([targets_0, targets_1])

    # batch_size length vector of sequence_lengths
    seq_lens = np.array([5, 5], dtype=np.int32)

    # output: batch_size length vector of negative log probabilities
    loss_truth = np.array([-loss_log_prob_0, -loss_log_prob_1], np.float32)

    # output: len max_time_steps array of 2 x depth matrices
    grad_truth = [np.vstack([gradient_log_prob_0[t, :],
                             gradient_log_prob_1[t, :]])
                  for t in range(5)] + 2 * [np.zeros((2, depth), np.float32)]

    # convert grad_truth into a [max_time x batch_size x depth] Tensor
    grad_truth = np.asarray(grad_truth)

    self._testCTCLoss(inputs, seq_lens, labels, loss_truth, grad_truth)


if __name__ == "__main__":
  tf.test.main()
tensorflow/contrib/ctc/ctc_ops.py (new file, 254 lines)
@@ -0,0 +1,254 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=unused-import
"""CTC (Connectionist Temporal Classification) Operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util

from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
# NOTE(ebrevdo): We redefine CTCLoss from gen_ctc_ops to only return
# the first output. The second output is only used for the gradient.
# pylint: disable=protected-access, invalid-name
def ctc_loss(inputs, labels, sequence_length,
             preprocess_collapse_repeated=False, ctc_merge_repeated=True):
  """Computes the CTC (Connectionist Temporal Classification) Loss.

  Requires:
    ```sequence_length(b) <= time for all b

       max(labels.indices(labels.indices[:, 1] == b, 2))
         <= sequence_length(b) for all b.```

  If ctc_merge_repeated is set False, then *during* CTC calculation
  repeated non-blank labels will not be merged and are interpreted
  as individual labels. This is a simplified version of CTC.

  Args:
    inputs: 3-D `float` `Tensor` sized
      `[max_time x batch_size x num_classes]`. The logits.
    labels: An `int32` `SparseTensor`.
      `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores
      the id for (batch b, time t). See `core/ops/ctc_ops.cc` for more details.
    sequence_length: 1-D `int32` vector, size `[batch_size]`.
      The sequence lengths.
    preprocess_collapse_repeated: Boolean. Default: False.
      If True, repeated labels are collapsed prior to the CTC calculation.
    ctc_merge_repeated: Boolean. Default: True.

  Returns:
    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
    probabilities (the CTC loss for each batch entry).

  Raises:
    TypeError: if labels is not a `SparseTensor`.
  """
  # The second, third, etc output tensors contain the gradients. We use it in
  # _CTCLossGrad() below.
  if not isinstance(labels, ops.SparseTensor):
    raise TypeError("Expected labels to be a SparseTensor")

  loss, _ = gen_ctc_ops._ctc_loss(
      inputs,
      labels.indices,
      labels.values,
      sequence_length,
      preprocess_collapse_repeated=preprocess_collapse_repeated,
      ctc_merge_repeated=ctc_merge_repeated)

  return loss
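A short usage sketch (an editor's illustration of the contrib API introduced here, not part of the commit; the placeholder shapes are made up and follow the docstring above):

import numpy as np
import tensorflow as tf

max_time, batch_size, num_classes = 7, 2, 6  # blank is assumed to be class 5
logits = tf.constant(
    np.random.randn(max_time, batch_size, num_classes).astype(np.float32))
labels = tf.SparseTensor(
    indices=tf.constant([[0, 0], [0, 1], [1, 0]], tf.int64),  # [batch, time]
    values=tf.constant([0, 1, 2], tf.int32),
    shape=tf.constant([batch_size, 2], tf.int64))
seq_lens = tf.constant([7, 7], tf.int32)

loss = tf.contrib.ctc.ctc_loss(logits, labels, seq_lens)  # shape [batch_size]
# The gradient registered below flows back to the logits only:
grad = tf.gradients(loss, [logits])[0]  # [max_time, batch_size, num_classes]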
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """The derivative provided by CTC Loss.

  Args:
    op: the CTCLoss op.
    grad_loss: The backprop for cost.

  Returns:
    The CTC Loss gradient.
  """
  # Outputs are: loss, grad
  grad = op.outputs[1]
  # Return gradient for inputs and None for
  # labels_indices, labels_values and sequence_length
  return [_BroadcastMul(grad_loss, grad), None, None, None]


@ops.RegisterShape("CTCLoss")
def _CTCLossShape(op):
  """Shape function for the CTCLoss op."""
  # inputs, label_indices, label_values, sequence_length
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[3].get_shape().with_rank(1)
  # merge batch_size
  sequence_length_shape[0].merge_with(inputs_shape[1])
  inputs_shape[1].merge_with(sequence_length_shape[0])
  batch_size = inputs_shape[1]
  labels_index_shape = op.inputs[1].get_shape().with_rank(2)
  labels_value_shape = op.inputs[2].get_shape().with_rank(1)
  labels_value_shape[0].merge_with(labels_index_shape[0])
  # loss, gradient
  return [tensor_shape.vector(batch_size), inputs_shape]
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
  """Performs greedy decoding on the logits given in input (best path).

  Note: Regardless of the value of merge_repeated, if the maximum index of a
  given time and batch corresponds to the blank index `(num_classes - 1)`, no
  new element is emitted.

  If merge_repeated is `True`, merge repeated classes in output.
  This means that if consecutive logits' maximum indices are the same,
  only the first of these is emitted. Labeling the blank '*', the sequence
  "A B B * B B" becomes "A B" if `merge_repeated = True` and "A B B B B"
  if `merge_repeated = False`.

  Args:
    inputs: 3-D `float` `Tensor` sized
      `[max_time x batch_size x num_classes]`. The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths,
      having size `[batch_size]`.
    merge_repeated: Boolean. Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A single-element list. `decoded[0]`
      is a `SparseTensor` containing the decoded outputs s.t.:
      `decoded.indices`: Indices matrix `(total_decoded_outputs x 2)`.
        The rows store: `[batch, time]`.
      `decoded.values`: Values vector, size `(total_decoded_outputs)`.
        The vector stores the decoded classes.
      `decoded.shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length]`
    log_probability: A `float` matrix `(batch_size x 1)` containing sequence
      log-probabilities.
  """
  outputs = gen_ctc_ops._ctc_greedy_decoder(
      inputs, sequence_length, merge_repeated=merge_repeated)
  (decoded_ix, decoded_val, decoded_shape, log_probabilities) = outputs
  return ([ops.SparseTensor(decoded_ix, decoded_val, decoded_shape)],
          log_probabilities)
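A minimal usage sketch (editor's addition; reuses the hypothetical `logits` and `seq_lens` placeholders from the ctc_loss sketch above, and this era's session API):

decoded, log_probs = tf.contrib.ctc.ctc_greedy_decoder(logits, seq_lens)
with tf.Session() as sess:
  d = sess.run(decoded[0])  # a SparseTensorValue: indices, values, shape
  # d.values holds the decoded class ids; d.shape is
  # [batch_size, max_decoded_length].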
@ops.RegisterShape("CTCGreedyDecoder")
def _CTCGreedyDecoderShape(op):
  """Shape function for the CTCGreedyDecoder op."""
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[1].get_shape().with_rank(1)
  # merge batch_size
  sequence_length_shape[0].merge_with(inputs_shape[1])
  inputs_shape[1].merge_with(sequence_length_shape[0])
  batch_size = inputs_shape[1]
  # decoded_indices, decoded_values, decoded_shape, log_probability
  return [tensor_shape.matrix(None, 2),
          tensor_shape.vector(None),
          tensor_shape.vector(2),
          tensor_shape.matrix(batch_size, 1)]
def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
                            top_paths=1, merge_repeated=True):
  """Performs beam search decoding on the logits given in input.

  If merge_repeated is `True`, merge repeated classes in output.
  This means that if consecutive entries in a beam are the same,
  only the first of these is emitted. That is, when the top path
  is "A B B B B", "A B" is returned if `merge_repeated = True`
  but "A B B B B" is returned if `merge_repeated = False`.

  Args:
    inputs: 3-D `float` `Tensor`, size
      `[max_time x batch_size x num_classes]`. The logits.
    sequence_length: 1-D `int32` vector containing sequence lengths,
      having size `[batch_size]`.
    beam_width: An int scalar >= 1 (beam search beam width).
    top_paths: An int scalar >= 1, <= beam_width (controls output size).
    merge_repeated: Boolean. Default: True.

  Returns:
    A tuple `(decoded, log_probabilities)` where

    decoded: A list of length top_paths, where `decoded[j]`
      is a `SparseTensor` containing the decoded outputs:
      `decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`.
        The rows store: [batch, time].
      `decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
        The vector stores the decoded classes for beam j.
      `decoded[j].shape`: Shape vector, size `(2)`.
        The shape values are: `[batch_size, max_decoded_length[j]]`.
    log_probability: A `float` matrix `(batch_size x top_paths)` containing
      sequence log-probabilities.
  """

  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
      gen_ctc_ops._ctc_beam_search_decoder(
          inputs, sequence_length, beam_width=beam_width, top_paths=top_paths,
          merge_repeated=merge_repeated))

  return (
      [ops.SparseTensor(ix, val, shape) for (ix, val, shape)
       in zip(decoded_ixs, decoded_vals, decoded_shapes)],
      log_probabilities)
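Usage sketch mirroring the beam search test earlier (editor's addition; the placeholder names are reused from the sketches above):

decoded, log_probs = tf.contrib.ctc.ctc_beam_search_decoder(
    logits, seq_lens, beam_width=2, top_paths=2)
with tf.Session() as sess:
  best, runner_up = sess.run([decoded[0], decoded[1]])
  # log_probs has shape [batch_size, 2]: one score per returned path.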
@ops.RegisterShape("CTCBeamSearchDecoder")
def _CTCBeamSearchDecoderShape(op):
  """Shape function for the CTCBeamSearchDecoder op."""
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[1].get_shape().with_rank(1)
  # merge batch size
  sequence_length_shape[0].merge_with(inputs_shape[1])
  inputs_shape[1].merge_with(sequence_length_shape[0])
  batch_size = inputs_shape[1]
  top_paths = op.get_attr("top_paths")

  # first the decoded indices
  output_shapes = [tensor_shape.matrix(None, 2) for _ in range(top_paths)]
  # next the decoded values
  output_shapes.extend([tensor_shape.vector(None) for _ in range(top_paths)])
  # the shapes of the decoded values
  output_shapes.extend([tensor_shape.vector(2)] * top_paths)
  # the log_probability matrix
  output_shapes.append(tensor_shape.matrix(batch_size, top_paths))
  return output_shapes


ops.NoGradient("CTCGreedyDecoder")

ops.NoGradient("CTCBeamSearchDecoder")
@@ -55,7 +55,6 @@ load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_tests")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gpu_kernel_library")

# For platform specific build config
load(
@@ -303,6 +302,7 @@ tf_gen_op_libs(
        "attention_ops",
        "candidate_sampling_ops",
        "control_flow_ops",
        "ctc_ops",
        "data_flow_ops",
        "function_ops",
        "functional_ops",
@@ -344,6 +344,7 @@ cc_library(
        ":attention_ops_op_lib",
        ":candidate_sampling_ops_op_lib",
        ":control_flow_ops_op_lib",
        ":ctc_ops_op_lib",
        ":data_flow_ops_op_lib",
        ":function_ops_op_lib",
        ":functional_ops_op_lib",
@@ -494,6 +495,7 @@ cc_library(
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:candidate_sampler_ops",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:ctc_ops",
        "//tensorflow/core/kernels:data_flow",
        "//tensorflow/core/kernels:fact_op",
        "//tensorflow/core/kernels:image",
@@ -592,7 +594,10 @@ filegroup(
 # sources.
 filegroup(
     name = "android_srcs",
-    srcs = ["//tensorflow/core/kernels:android_srcs"] + glob(
+    srcs = [
+        "//tensorflow/core/kernels:android_srcs",
+        "//tensorflow/core/util/ctc:android_srcs",
+    ] + glob(
         [
             "client/**/*.cc",
             "common_runtime/**/*.h",
@@ -335,6 +335,19 @@ tf_kernel_library(
    ],
)

tf_kernel_library(
    name = "ctc_ops",
    prefix = "ctc",
    deps = [
        ":ops_util",
        "//tensorflow/core:ctc_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/util/ctc:ctc_beam_search_lib",
        "//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
    ],
)

tf_cc_test(
    name = "control_flow_ops_test",
    deps = [
@@ -1081,6 +1094,7 @@ filegroup(
filegroup(
    name = "android_extended_ops_group2",
    srcs = [
        "ctc_decoder_ops.cc",
        "dynamic_stitch_op.cc",
        "in_topk_op.cc",
        "lrn_op.cc",
tensorflow/core/kernels/ctc_decoder_ops.cc (new file, 323 lines)
@@ -0,0 +1,323 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#define EIGEN_USE_THREADS

#include <limits>

#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Returns the max value in row r of m and writes its column index to *c.
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
                    int* c) {
  *c = 0;
  CHECK_LT(0, m.dimension(1));
  float p = m(r, 0);
  for (int i = 1; i < m.dimension(1); ++i) {
    if (m(r, i) > p) {
      p = m(r, i);
      *c = i;
    }
  }
  return p;
}
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const TensorShape& inputs_shape = (*inputs)->shape();

    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }

    if (!(batch_size == (*seq_len)->dim_size(0))) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size. ", "len(sequence_length): ",
          (*seq_len)->dim_size(0), " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();

    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b, ") <= ",
                                          max_time);
      }
    }

    Status s = ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return Status::OK();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    // Calculate the total number of entries for each path
    const int batch_size = sequences.size();
    std::vector<int64> num_entries(top_paths_, 0);

    // Calculate num_entries per path
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64 p_num = num_entries[p];

      Status s =
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<int64>();
      auto values_t = p_values->vec<int64>();
      auto shape_t = p_shape->vec<int64>();

      int64 max_decoded = 0;
      int64 offset = 0;

      for (int b = 0; b < batch_size; ++b) {
        auto& p_batch = sequences[b][p];
        int64 num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        for (int t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
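To visualize what StoreAllDecodedSequences packs into each output SparseTensor, here is an equivalent Python sketch for a single path p (an editor's illustration, not part of the commit):

def store_decoded_path(sequences, p):
  """Pack sequences[b][p] into (indices, values, shape) for one path p."""
  indices, values = [], []
  max_decoded = 0
  for b, batch_paths in enumerate(sequences):
    seq = batch_paths[p]
    max_decoded = max(max_decoded, len(seq))
    for t, v in enumerate(seq):
      indices.append([b, t])  # rows store [batch, time]
      values.append(v)
  shape = [len(sequences), max_decoded]
  return indices, values, shape

# store_decoded_path([[[1, 0]], [[1, 1, 0]]], 0) ->
#   ([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], [1, 0, 1, 1, 0], [2, 3])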
class CTCGreedyDecoderOp : public OpKernel {
 public:
  explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    const TensorShape& inputs_shape = inputs->shape();

    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes = inputs_shape.dim_size(2);

    auto inputs_t = inputs->tensor<float, 3>();

    for (std::size_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<float>();

    log_prob_t.setZero();

    // Assumption: the blank index is num_classes - 1
    int blank_index = num_classes - 1;

    // Perform best path decoding
    std::vector<std::vector<std::vector<int> > > sequences(batch_size);
    for (int b = 0; b < batch_size; ++b) {
      sequences[b].resize(1);
      auto& sequence = sequences[b][0];
      int prev_indices = -1;
      for (int t = 0; t < seq_len_t(b); ++t) {
        int max_class_indices;
        log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
        if (max_class_indices != blank_index &&
            !(merge_repeated_ && max_class_indices == prev_indices)) {
          sequence.push_back(max_class_indices);
        }
        prev_indices = max_class_indices;
      }
    }

    OP_REQUIRES_OK(
        ctx, decode_helper_.StoreAllDecodedSequences(
                 sequences, &decoded_indices, &decoded_values, &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
                        CTCGreedyDecoderOp);
// CTC beam search
class CTCBeamSearchDecoderOp : public OpKernel {
 public:
  explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
    int top_paths;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
    decode_helper_.SetTopPaths(top_paths);
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* seq_len;
    Tensor* log_prob = nullptr;
    OpOutputList decoded_indices;
    OpOutputList decoded_values;
    OpOutputList decoded_shape;
    OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
                            ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                            &decoded_values, &decoded_shape));

    auto inputs_t = inputs->tensor<float, 3>();
    auto seq_len_t = seq_len->vec<int32>();
    auto log_prob_t = log_prob->matrix<float>();

    const TensorShape& inputs_shape = inputs->shape();

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes = inputs_shape.dim_size(2);

    log_prob_t.setZero();

    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;

    for (std::size_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
    }

    ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_);
    Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
    auto input_chip_t = input_chip.flat<float>();

    std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
    std::vector<float> log_probs;

    // Assumption: the blank index is num_classes - 1
    for (int b = 0; b < batch_size; ++b) {
      auto& best_paths_b = best_paths[b];
      best_paths_b.resize(decode_helper_.GetTopPaths());
      for (int t = 0; t < seq_len_t(b); ++t) {
        input_chip_t = input_list_t[t].chip(b, 0);
        auto input_bi =
            Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
        beam_search.Step(input_bi);
      }
      beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                           &log_probs, merge_repeated_);
      beam_search.Reset();

      for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
        log_prob_t(b, bp) = log_probs[bp];
      }
    }

    OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
                            best_paths, &decoded_indices, &decoded_values,
                            &decoded_shape));
  }

 private:
  CTCDecodeHelper decode_helper_;
  bool merge_repeated_;
  int beam_width_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
                        CTCBeamSearchDecoderOp);

}  // end namespace tensorflow
tensorflow/core/kernels/ctc_loss_op.cc (new file, 157 lines)
@@ -0,0 +1,157 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/ctc_ops.cc.

#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

class CTCLossOp : public OpKernel {
  typedef Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
                                         Eigen::RowMajor> >
      InputMap;
  typedef Eigen::Map<
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> >
      OutputMap;

 public:
  explicit CTCLossOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));

    const TensorShape& inputs_shape = inputs->shape();
    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);
    const int64 num_classes = inputs_shape.dim_size(2);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size. ",
                                "len(sequence_length): ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));
    auto seq_len_t = seq_len->vec<int32>();

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));

    TensorShape labels_shape({batch_size, max_time});
    std::vector<int64> order{0, 1};
    sparse::SparseTensor labels_sp(*labels_indices, *labels_values,
                                   labels_shape, order);

    Status labels_sp_valid = labels_sp.IndicesValid();
    OP_REQUIRES(ctx, labels_sp_valid.ok(),
                errors::InvalidArgument("label SparseTensor is not valid: ",
                                        labels_sp_valid.error_message()));

    ctc::CTCLossCalculator::LabelSequences labels_t(batch_size);
    for (const auto& g : labels_sp.group({0})) {  // iterate by batch
      const int batch_indices = g.group()[0];
      OP_REQUIRES(ctx, batch_indices >= 0 && batch_indices < batch_size,
                  errors::InvalidArgument("labels batch index must be between ",
                                          0, " and ", batch_size, " but saw: ",
                                          batch_indices));

      auto values = g.values<int32>();
      std::vector<int>* b_values = &labels_t[batch_indices];
      b_values->resize(values.size());
      for (int i = 0; i < values.size(); ++i) (*b_values)[i] = values(i);
    }

    OP_REQUIRES(ctx, static_cast<size_t>(batch_size) == labels_t.size(),
                errors::InvalidArgument("len(labels) != batch_size. ",
                                        "len(labels): ", labels_t.size(),
                                        " batch_size: ", batch_size));

    for (int64 b = 0; b < batch_size; ++b) {
      OP_REQUIRES(
          ctx, seq_len_t(b) <= max_time,
          errors::InvalidArgument("sequence_length(", b, ") <= ", max_time));
    }

    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
    auto loss_t = loss->vec<float>();

    Tensor* gradient;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));
    auto gradient_t = gradient->tensor<float, 3>();
    auto inputs_t = inputs->tensor<float, 3>();
    std::vector<OutputMap> gradient_list_t;
    std::vector<InputMap> input_list_t;

    for (std::size_t t = 0; t < max_time; ++t) {
      input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                batch_size, num_classes);
      gradient_list_t.emplace_back(
          gradient_t.data() + t * batch_size * num_classes, batch_size,
          num_classes);
    }

    gradient_t.setZero();

    // Assumption: the blank index is num_classes - 1
    ctc::CTCLossCalculator ctc_loss_calculator(num_classes - 1, 0);
    OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
                            seq_len_t, labels_t, input_list_t,
                            preprocess_collapse_repeated_, ctc_merge_repeated_,
                            &loss_t, &gradient_list_t));
  }

 private:
  bool preprocess_collapse_repeated_;
  bool ctc_merge_repeated_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
};

REGISTER_KERNEL_BUILDER(Name("CTCLoss").Device(DEVICE_CPU), CTCLossOp);

}  // end namespace tensorflow
tensorflow/core/ops/ctc_ops.cc (new file, 123 lines)
@@ -0,0 +1,123 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// CTC is Connectionist Temporal Classification. See util/ctc/ for details.

REGISTER_OP("CTCLoss")
    .Input("inputs: float")
    .Input("labels_indices: int64")
    .Input("labels_values: int32")
    .Input("sequence_length: int32")
    .Attr("preprocess_collapse_repeated: bool = false")
    .Attr("ctc_merge_repeated: bool = true")
    .Output("loss: float")
    .Output("gradient: float")
    .Doc(R"doc(
Calculates the CTC Loss (log probability) for each batch entry. Also calculates
the gradient. This op performs the softmax operation for you, so inputs
should be e.g. linear projections of outputs by an LSTM.

inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits.
labels_indices: The indices of a `SparseTensor<int32, 2>`.
  `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for
  `(batch b, time t)`.
labels_values: The values (labels) associated with the given batch and time.
sequence_length: A vector containing sequence lengths (batch).
preprocess_collapse_repeated: Scalar, if true then repeated labels are
  collapsed prior to the CTC calculation.
ctc_merge_repeated: Scalar. If set to false, *during* CTC calculation
  repeated non-blank labels will not be merged and are interpreted as
  individual labels. This is a simplified version of CTC.
loss: A vector (batch) containing log-probabilities.
gradient: The gradient of `loss`. 3-D, shape:
  `(max_time x batch_size x num_classes)`.
)doc");
REGISTER_OP("CTCGreedyDecoder")
|
||||
.Input("inputs: float")
|
||||
.Input("sequence_length: int32")
|
||||
.Attr("merge_repeated: bool = false")
|
||||
.Output("decoded_indices: int64")
|
||||
.Output("decoded_values: int64")
|
||||
.Output("decoded_shape: int64")
|
||||
.Output("log_probability: float")
|
||||
.Doc(R"doc(
|
||||
Performs greedy decoding on the logits given in inputs.
|
||||
|
||||
A note about the attribute merge_repeated: if enabled, when
|
||||
consecutive logits' maximum indices are the same, only the first of
|
||||
these is emitted. Labeling the blank '*', the sequence "A B B * B B"
|
||||
becomes "A B" if merge_repeated = True and "A B B B B" if
|
||||
merge_repeated = False.
|
||||
|
||||
Regardless of the value of merge_repeated, if the maximum index of a given
|
||||
time and batch corresponds to the blank, index `(num_classes - 1)`, no new
|
||||
element is emitted.
|
||||
|
||||
inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits.
|
||||
sequence_length: A vector containing sequence lengths, size `(batch_size)`.
|
||||
merge_repeated: If True, merge repeated classes in output.
|
||||
decoded_indices: Indices matrix, size `(total_decoded_outputs x 2)`,
|
||||
of a `SparseTensor<int64, 2>`. The rows store: [batch, time].
|
||||
decoded_values: Values vector, size: `(total_decoded_outputs)`,
|
||||
of a `SparseTensor<int64, 2>`. The vector stores the decoded classes.
|
||||
decoded_shape: Shape vector, size `(2)`, of the decoded SparseTensor.
|
||||
Values are: `[batch_size, max_decoded_length]`.
|
||||
log_probability: Matrix, size `(batch_size x 1)`, containing sequence
|
||||
log-probabilities.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("CTCBeamSearchDecoder")
|
||||
.Input("inputs: float")
|
||||
.Input("sequence_length: int32")
|
||||
.Attr("beam_width: int >= 1")
|
||||
.Attr("top_paths: int >= 1")
|
||||
.Attr("merge_repeated: bool = true")
|
||||
.Output("decoded_indices: top_paths * int64")
|
||||
.Output("decoded_values: top_paths * int64")
|
||||
.Output("decoded_shape: top_paths * int64")
|
||||
.Output("log_probability: float")
|
||||
.Doc(R"doc(
|
||||
Performs beam search decoding on the logits given in input.
|
||||
|
||||
A note about the attribute merge_repeated: For the beam search decoder,
|
||||
this means that if consecutive entries in a beam are the same, only
|
||||
the first of these is emitted. That is, when the top path is "A B B B B",
|
||||
"A B" is returned if merge_repeated = True but "A B B B B" is
|
||||
returned if merge_repeated = False.
|
||||
|
||||
inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits.
|
||||
sequence_length: A vector containing sequence lengths, size `(batch)`.
|
||||
beam_width: A scalar >= 1 (beam search beam width).
top_paths: A scalar >= 1, <= beam_width (controls output size).
|
||||
merge_repeated: If true, merge repeated classes in output.
|
||||
decoded_indices: A list (length: top_paths) of indices matrices. Matrix j,
|
||||
size `(total_decoded_outputs[j] x 2)`, has indices of a
|
||||
`SparseTensor<int64, 2>`. The rows store: [batch, time].
|
||||
decoded_values: A list (length: top_paths) of values vectors. Vector j,
|
||||
size `(length total_decoded_outputs[j])`, has the values of a
|
||||
`SparseTensor<int64, 2>`. The vector stores the decoded classes for beam j.
|
||||
decoded_shape: A list (length: top_paths) of shape vector. Vector j,
|
||||
size `(2)`, stores the shape of the decoded `SparseTensor[j]`.
|
||||
Its values are: `[batch_size, max_decoded_length[j]]`.
|
||||
log_probability: A matrix, shaped: `(batch_size x top_paths)`. The
|
||||
sequence log-probabilities.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
@ -1608,6 +1608,160 @@ op {
|
||||
summary: "Return the reduction indices for computing gradients of s0 op s1 with broadcast."
|
||||
description: "This is typically used by gradient computations for a broadcasting operation."
|
||||
}
|
||||
op {
|
||||
name: "CTCBeamSearchDecoder"
|
||||
input_arg {
|
||||
name: "inputs"
|
||||
description: "3-D, shape: `(max_time x batch_size x num_classes)`, the logits."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "sequence_length"
|
||||
description: "A vector containing sequence lengths, size `(batch)`."
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "decoded_indices"
|
||||
description: "A list (length: top_paths) of indices matrices. Matrix j,\nsize `(total_decoded_outputs[j] x 2)`, has indices of a\n`SparseTensor<int64, 2>`. The rows store: [batch, time]."
|
||||
type: DT_INT64
|
||||
number_attr: "top_paths"
|
||||
}
|
||||
output_arg {
|
||||
name: "decoded_values"
|
||||
description: "A list (length: top_paths) of values vectors. Vector j,\nsize `(length total_decoded_outputs[j])`, has the values of a\n`SparseTensor<int64, 2>`. The vector stores the decoded classes for beam j."
|
||||
type: DT_INT64
|
||||
number_attr: "top_paths"
|
||||
}
|
||||
output_arg {
|
||||
name: "decoded_shape"
|
||||
description: "A list (length: top_paths) of shape vector. Vector j,\nsize `(2)`, stores the shape of the decoded `SparseTensor[j]`.\nIts values are: `[batch_size, max_decoded_length[j]]`."
|
||||
type: DT_INT64
|
||||
number_attr: "top_paths"
|
||||
}
|
||||
output_arg {
|
||||
name: "log_probability"
|
||||
description: "A matrix, shaped: `(batch_size x top_paths)`. The\nsequence log-probabilities."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
attr {
|
||||
name: "beam_width"
|
||||
type: "int"
|
||||
description: "A scalar >= 0 (beam search beam width)."
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "top_paths"
|
||||
type: "int"
|
||||
description: "A scalar >= 0, <= beam_width (controls output size)."
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "merge_repeated"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
description: "If true, merge repeated classes in output."
|
||||
}
|
||||
summary: "Performs beam search decoding on the logits given in input."
|
||||
description: "A note about the attribute merge_repeated: For the beam search decoder,\nthis means that if consecutive entries in a beam are the same, only\nthe first of these is emitted. That is, when the top path is \"A B B B B\",\n\"A B\" is returned if merge_repeated = True but \"A B B B B\" is\nreturned if merge_repeated = False."
|
||||
}
|
||||
op {
|
||||
name: "CTCGreedyDecoder"
|
||||
input_arg {
|
||||
name: "inputs"
|
||||
description: "3-D, shape: `(max_time x batch_size x num_classes)`, the logits."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "sequence_length"
|
||||
description: "A vector containing sequence lengths, size `(batch_size)`."
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "decoded_indices"
|
||||
description: "Indices matrix, size `(total_decoded_outputs x 2)`,\nof a `SparseTensor<int64, 2>`. The rows store: [batch, time]."
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "decoded_values"
|
||||
description: "Values vector, size: `(total_decoded_outputs)`,\nof a `SparseTensor<int64, 2>`. The vector stores the decoded classes."
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "decoded_shape"
|
||||
description: "Shape vector, size `(2)`, of the decoded SparseTensor.\nValues are: `[batch_size, max_decoded_length]`."
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "log_probability"
|
||||
description: "Matrix, size `(batch_size x 1)`, containing sequence\nlog-probabilities."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
attr {
|
||||
name: "merge_repeated"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "If True, merge repeated classes in output."
|
||||
}
|
||||
summary: "Performs greedy decoding on the logits given in inputs."
|
||||
description: "A note about the attribute merge_repeated: if enabled, when\nconsecutive logits\' maximum indices are the same, only the first of\nthese is emitted. Labeling the blank \'*\', the sequence \"A B B * B B\"\nbecomes \"A B\" if merge_repeated = True and \"A B B B B\" if\nmerge_repeated = False.\n\nRegardless of the value of merge_repeated, if the maximum index of a given\ntime and batch corresponds to the blank, index `(num_classes - 1)`, no new\nelement is emitted."
|
||||
}
|
||||
op {
|
||||
name: "CTCLoss"
|
||||
input_arg {
|
||||
name: "inputs"
|
||||
description: "3-D, shape: `(max_time x batch_size x num_classes)`, the logits."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "labels_indices"
|
||||
description: "The indices of a `SparseTensor<int32, 2>`.\n`labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for\n`(batch b, time t)`."
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "labels_values"
|
||||
description: "The values (labels) associated with the given batch and time."
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "sequence_length"
|
||||
description: "A vector containing sequence lengths (batch)."
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "loss"
|
||||
description: "A vector (batch) containing log-probabilities."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "gradient"
|
||||
description: "The gradient of `loss`. 3-D, shape:\n`(max_time x batch_size x num_classes)`."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
attr {
|
||||
name: "preprocess_collapse_repeated"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
description: "Scalar, if true then repeated labels are\ncollapsed prior to the CTC calculation."
|
||||
}
|
||||
attr {
|
||||
name: "ctc_merge_repeated"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
description: "Scalar. If set to false, *during* CTC calculation\nrepeated non-blank labels will not be merged and are interpreted as\nindividual labels. This is a simplified version of CTC."
|
||||
}
|
||||
summary: "Calculates the CTC Loss (log probability) for each batch entry. Also calculates"
|
||||
description: "the gradient. This class performs the softmax operation for you, so inputs\nshould be e.g. linear projections of outputs by an LSTM."
|
||||
}
|
||||
op {
|
||||
name: "Cast"
|
||||
input_arg {
|
||||
|
88
tensorflow/core/util/ctc/BUILD
Normal file
@ -0,0 +1,88 @@
|
||||
# Description: CTC, Connectionist Temporal Classification,
|
||||
# is a type of seq2seq loss. The libraries in this directory
|
||||
# implement the CTC loss and a number of CTC decoders.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_tests")
|
||||
|
||||
filegroup(
|
||||
name = "android_srcs",
|
||||
srcs = [
|
||||
"ctc_beam_entry.h",
|
||||
"ctc_beam_scorer.h",
|
||||
"ctc_beam_search.h",
|
||||
"ctc_decoder.h",
|
||||
"ctc_loss_util.h",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ctc",
|
||||
deps = [
|
||||
":ctc_beam_search_lib",
|
||||
":ctc_loss_calculator_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ctc_beam_search_lib",
|
||||
srcs = [
|
||||
"ctc_beam_entry.h",
|
||||
"ctc_beam_scorer.h",
|
||||
"ctc_beam_search.h",
|
||||
"ctc_decoder.h",
|
||||
],
|
||||
hdrs = [
|
||||
"ctc_beam_entry.h",
|
||||
"ctc_beam_scorer.h",
|
||||
"ctc_beam_search.h",
|
||||
"ctc_decoder.h",
|
||||
],
|
||||
deps = [
|
||||
":ctc_loss_util_lib",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
tests = [
|
||||
"ctc_beam_search_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":ctc_beam_search_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ctc_loss_calculator_lib",
|
||||
srcs = [
|
||||
"ctc_loss_calculator.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ctc_loss_calculator.h",
|
||||
],
|
||||
deps = [
|
||||
":ctc_loss_util_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ctc_loss_util_lib",
|
||||
hdrs = [
|
||||
"ctc_loss_util.h",
|
||||
],
|
||||
deps = ["//tensorflow/core:lib"],
|
||||
)
|
128
tensorflow/core/util/ctc/ctc_beam_entry.h
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/ctc/ctc_loss_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ctc {
|
||||
|
||||
// The ctc_beam_search namespace holds several classes meant to be accessed only
|
||||
// in case of extending the CTCBeamSearch decoder to allow custom scoring
|
||||
// functions.
|
||||
//
|
||||
// BeamEntry is exposed through template arguments BeamScorer and BeamComparer
|
||||
// of CTCBeamSearch (ctc_beam_search.h).
|
||||
namespace ctc_beam_search {
|
||||
|
||||
struct EmptyBeamState {};
|
||||
|
||||
struct BeamProbability {
|
||||
BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
|
||||
void Reset() {
|
||||
total = kLogZero;
|
||||
blank = kLogZero;
|
||||
label = kLogZero;
|
||||
}
|
||||
float total;
|
||||
float blank;
|
||||
float label;
|
||||
};
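// kLogZero and LogSumExp come from ctc_loss_util.h, which is not in the
// hunks shown here. A plausible minimal version, as a reading aid (the
// exact definitions are an assumption, grounded only in how they are used):
//   const float kLogZero = -std::numeric_limits<float>::infinity();
//   inline float LogSumExp(float a, float b) {  // log(exp(a) + exp(b))
//     if (a == kLogZero) return b;
//     if (b == kLogZero) return a;
//     return (a > b) ? a + log1pf(expf(b - a)) : b + log1pf(expf(a - b));
//   }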
|
||||
|
||||
template <class CTCBeamState = EmptyBeamState>
|
||||
struct BeamEntry {
|
||||
// Default constructor does not create a vector of children.
|
||||
BeamEntry() : parent(nullptr), label(-1) {}
|
||||
// Constructor giving parent, label, and number of children does
|
||||
// create a vector of children. The object pointed to by p
|
||||
// cannot be copied and should not be moved, otherwise parent will
|
||||
// become invalid.
|
||||
BeamEntry(BeamEntry* p, int l, int L, int t) : parent(p), label(l) {
|
||||
PopulateChildren(L);
|
||||
}
|
||||
inline bool Active() const { return newp.total != kLogZero; }
|
||||
inline bool HasChildren() const { return !children.empty(); }
|
||||
void PopulateChildren(int L) {
|
||||
CHECK(!HasChildren());
|
||||
children = std::vector<BeamEntry>(L);
|
||||
int ci = 0;
|
||||
for (auto& c : children) {
|
||||
// The current object cannot be copied, and should not be moved.
|
||||
// Otherwise the child's parent will become invalid.
|
||||
c.parent = this;
|
||||
c.label = ci;
|
||||
++ci;
|
||||
}
|
||||
}
|
||||
inline std::vector<BeamEntry>* Children() {
|
||||
CHECK(HasChildren());
|
||||
return &children;
|
||||
}
|
||||
inline const std::vector<BeamEntry>* Children() const {
|
||||
CHECK(HasChildren());
|
||||
return &children;
|
||||
}
|
||||
std::vector<int> LabelSeq(bool merge_repeated) const {
|
||||
std::vector<int> labels;
|
||||
int prev_label = -1;
|
||||
const BeamEntry* c = this;
|
||||
while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
|
||||
if (!merge_repeated || c->label != prev_label) {
|
||||
labels.push_back(c->label);
|
||||
}
|
||||
prev_label = c->label;
|
||||
c = c->parent;
|
||||
}
|
||||
std::reverse(labels.begin(), labels.end());
|
||||
return labels;
|
||||
}
|
||||
|
||||
BeamEntry<CTCBeamState>* parent;
|
||||
int label;
|
||||
std::vector<BeamEntry<CTCBeamState>> children;
|
||||
BeamProbability oldp;
|
||||
BeamProbability newp;
|
||||
CTCBeamState state;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
|
||||
};
|
||||
|
||||
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
|
||||
template <class CTCBeamState = EmptyBeamState>
|
||||
class BeamComparer {
|
||||
public:
|
||||
virtual ~BeamComparer() {}
|
||||
virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
|
||||
const BeamEntry<CTCBeamState>* b) const {
|
||||
return a->newp.total > b->newp.total;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ctc_beam_search
|
||||
|
||||
} // namespace ctc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
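As a reading aid, a minimal sketch (not part of the commit) of how LabelSeq()
reconstructs a beam's label sequence by walking parent pointers back to the
root; the alphabet size and labels here are made up:

#include <vector>
#include "tensorflow/core/util/ctc/ctc_beam_entry.h"

int main() {
  using tensorflow::ctc::ctc_beam_search::BeamEntry;
  const int num_labels = 3;  // made-up alphabet size (excluding the blank)

  BeamEntry<> root;  // default root: parent == nullptr, label == -1
  root.PopulateChildren(num_labels);
  BeamEntry<>* b1 = &(*root.Children())[1];  // beam "1"
  b1->PopulateChildren(num_labels);
  BeamEntry<>* b11 = &(*b1->Children())[1];  // beam "1 1"

  // LabelSeq walks up to (but not including) the root, then reverses.
  std::vector<int> merged = b11->LabelSeq(/*merge_repeated=*/true);  // {1}
  std::vector<int> full = b11->LabelSeq(/*merge_repeated=*/false);   // {1, 1}
  return (merged.size() == 1 && full.size() == 2) ? 0 : 1;
}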
|
94
tensorflow/core/util/ctc/ctc_beam_scorer.h
Normal file
@ -0,0 +1,94 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Collection of scoring classes that can be extended and provided to the
|
||||
// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a
|
||||
// language model).
|
||||
//
|
||||
// To build a custom scorer extend and implement the pure virtual methods from
|
||||
// BeamScorerInterface. The default CTC decoding behavior is implemented
|
||||
// through BaseBeamScorer.
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_
|
||||
|
||||
#include "tensorflow/core/util/ctc/ctc_beam_entry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ctc {
|
||||
|
||||
// BeamScorerInterface can be subclassed and provided as a template argument to
|
||||
// CTCBeamSearchDecoder, if complex scoring is required. Its main purpose is to
|
||||
// provide a thin layer for integrating language model scoring easily.
|
||||
template <typename CTCBeamState>
|
||||
class BeamScorerInterface {
|
||||
public:
|
||||
virtual ~BeamScorerInterface() {}
|
||||
|
||||
// State initialization.
|
||||
virtual inline void InitializeState(CTCBeamState* root) const = 0;
|
||||
|
||||
// ExpandState is called when expanding a beam to one of its children.
|
||||
// Called at most once per child beam.
|
||||
virtual void ExpandState(const CTCBeamState& from_state, int from_label,
|
||||
CTCBeamState* to_state, int to_label) const = 0;
|
||||
|
||||
// ExpandStateEnd is called after decoding has finished. Its purpose is to
|
||||
// allow a final scoring of the beam in its current state, before resorting
|
||||
// and retrieving the TopN requested candidates. Called at most once per beam.
|
||||
virtual void ExpandStateEnd(CTCBeamState* state) const = 0;
|
||||
|
||||
// GetStateExpansionScore should be an inexpensive method to retrieve the
|
||||
// (cached) expansion score computed within ExpandState. The score is
|
||||
// multiplied (log-addition) with the input score at the current step from
|
||||
// the network.
|
||||
//
|
||||
// The score returned should be a log-probability.
|
||||
virtual float GetStateExpansionScore(const CTCBeamState& state,
|
||||
float previous_score) const = 0;
|
||||
|
||||
// GetStateEndExpansionScore should be an inexpensive method to retrieve the
|
||||
// (cached) expansion score computed within ExpandStateEnd. The score is
|
||||
// multiplied (log-addition) with the final probability of the beam.
|
||||
//
|
||||
// The score returned should be a log-probability.
|
||||
virtual float GetStateEndExpansionScore(const CTCBeamState& state) const = 0;
|
||||
};
|
||||
|
||||
// Base implementation of BeamScorer used by default by the decoder.
|
||||
template <typename CTCBeamState>
|
||||
class BaseBeamScorer : public BeamScorerInterface<CTCBeamState> {
|
||||
public:
|
||||
~BaseBeamScorer() override {}
|
||||
|
||||
// In the simplest case, no state expansion is done.
|
||||
void InitializeState(CTCBeamState* root) const override {}
|
||||
void ExpandState(const CTCBeamState& from_state, int from_label,
|
||||
CTCBeamState* to_state, int to_label) const override {}
|
||||
void ExpandStateEnd(CTCBeamState* state) const override {}
|
||||
// As there's no state expansion logic, the expansion score is zero.
|
||||
float GetStateExpansionScore(const CTCBeamState& state,
|
||||
float previous_score) const override {
|
||||
return previous_score;
|
||||
}
|
||||
float GetStateEndExpansionScore(const CTCBeamState& state) const override {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ctc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_
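To make the extension point concrete, a minimal hypothetical scorer
(illustrative only; the dictionary scorer in ctc_beam_search_test.cc later in
this commit is the real worked example). It folds a fixed log-penalty into
every label expansion, and is default-constructible because
CTCBeamSearchDecoder news up its scorer internally:

#include "tensorflow/core/util/ctc/ctc_beam_scorer.h"

struct PenaltyBeamState {
  float score = 0.0f;  // cached expansion score, per the interface contract
};

class PenaltyBeamScorer
    : public tensorflow::ctc::BeamScorerInterface<PenaltyBeamState> {
 public:
  void InitializeState(PenaltyBeamState* root) const override {
    root->score = 0.0f;
  }
  void ExpandState(const PenaltyBeamState& from_state, int from_label,
                   PenaltyBeamState* to_state, int to_label) const override {
    to_state->score = kLogPenalty;  // cached here, read back in the getter
  }
  void ExpandStateEnd(PenaltyBeamState* state) const override {}
  float GetStateExpansionScore(const PenaltyBeamState& state,
                               float previous_score) const override {
    return previous_score + state.score;  // log-space multiplication
  }
  float GetStateEndExpansionScore(const PenaltyBeamState& state) const override {
    return 0.0f;  // no end-of-sequence rescoring
  }

 private:
  static constexpr float kLogPenalty = -0.1f;  // arbitrary illustrative value
};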
|
319
tensorflow/core/util/ctc/ctc_beam_search.h
Normal file
@ -0,0 +1,319 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/lib/gtl/top_n.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/ctc/ctc_beam_entry.h"
|
||||
#include "tensorflow/core/util/ctc/ctc_beam_scorer.h"
|
||||
#include "tensorflow/core/util/ctc/ctc_decoder.h"
|
||||
#include "tensorflow/core/util/ctc/ctc_loss_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ctc {
|
||||
|
||||
template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
|
||||
class CTCBeamScorer = BaseBeamScorer<CTCBeamState>,
|
||||
typename CTCBeamComparer =
|
||||
ctc_beam_search::BeamComparer<CTCBeamState>>
|
||||
class CTCBeamSearchDecoder : public CTCDecoder {
|
||||
// Beam Search
|
||||
//
|
||||
// Example (GravesTh Fig. 7.5):
|
||||
// a -
|
||||
// P = [ 0.3 0.7 ] t = 0
|
||||
// [ 0.4 0.6 ] t = 1
|
||||
//
|
||||
// Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
|
||||
// P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58
|
||||
//
|
||||
// In this case, Best Path decoding is suboptimal.
|
||||
//
|
||||
// For Beam Search, we use the following main recurrence relations:
|
||||
//
|
||||
// Relation 1:
|
||||
// ---------------------------------------------------------- Eq. 1
|
||||
// P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7)
|
||||
// + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
|
||||
// where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
|
||||
// updated recursively in the beam entry.
|
||||
//
|
||||
// Relation 2:
|
||||
// ---------------------------------------------------------- Eq. 2
|
||||
// P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
|
||||
// for ? in a, b, d, ..., (not including c or the blank index),
|
||||
// and the recurrence starts from the beam entry for P(l=abc @ t=2).
|
||||
//
|
||||
// For this case, the length of the new sequence equals t+1 (t
|
||||
// starts at 0). This special case can be calculated as:
|
||||
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
|
||||
// but we calculate it recursively for speed purposes.
|
||||
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
|
||||
typedef ctc_beam_search::BeamProbability BeamProbability;
|
||||
|
||||
public:
|
||||
CTCBeamSearchDecoder(int num_classes, int beam_width)
|
||||
: CTCDecoder(num_classes, 1, false),
|
||||
beam_width_(beam_width),
|
||||
leaves_(beam_width),
|
||||
beam_scorer_(new CTCBeamScorer) {
|
||||
Reset();
|
||||
}
|
||||
|
||||
CTCBeamSearchDecoder(int num_classes, int beam_width, int batch_size,
|
||||
bool merge_repeated)
|
||||
: CTCDecoder(num_classes, batch_size, merge_repeated),
|
||||
beam_width_(beam_width),
|
||||
leaves_(beam_width),
|
||||
beam_scorer_(new CTCBeamScorer) {}
|
||||
|
||||
~CTCBeamSearchDecoder() override {}
|
||||
|
||||
// Run the hibernating beam search algorithm on the given input.
|
||||
void Decode(const CTCDecoder::SequenceLength& seq_len,
|
||||
const std::vector<CTCDecoder::Input>& input,
|
||||
std::vector<CTCDecoder::Output>* output,
|
||||
CTCDecoder::ScoreOutput* scores) override;
|
||||
|
||||
// Calculate the next step of the beam search and update the internal state.
|
||||
template <typename Vector>
|
||||
void Step(const Vector& log_input_t);
|
||||
|
||||
// Retrieve the beam scorer instance used during decoding.
|
||||
CTCBeamScorer* GetBeamScorer() { return beam_scorer_.get(); }
|
||||
|
||||
// Reset the beam search
|
||||
void Reset();
|
||||
|
||||
// Extract the top n paths at current time step
|
||||
void TopPaths(int n, std::vector<std::vector<int>>* paths,
|
||||
std::vector<float>* log_probs, bool merge_repeated) const;
|
||||
|
||||
private:
|
||||
int beam_width_;
|
||||
|
||||
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
|
||||
std::unique_ptr<BeamEntry> beam_root_;
|
||||
std::unique_ptr<CTCBeamScorer> beam_scorer_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
|
||||
};
|
||||
|
||||
template <typename CTCBeamState, class CTCBeamScorer, typename CTCBeamComparer>
|
||||
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamScorer, CTCBeamComparer>::Decode(
|
||||
const CTCDecoder::SequenceLength& seq_len,
|
||||
const std::vector<CTCDecoder::Input>& input, std::vector<CTCDecoder::Output>* output,
|
||||
ScoreOutput* scores) {
|
||||
// Storage for top paths.
|
||||
std::vector<std::vector<int>> beams;
|
||||
std::vector<float> beam_log_probabilities;
|
||||
int top_n = output->size();
|
||||
|
||||
for (int b = 0; b < batch_size_; ++b) {
|
||||
int seq_len_b = seq_len[b];
|
||||
Reset();
|
||||
|
||||
for (int t = 0; t < seq_len_b; ++t) {
|
||||
// Pass log-probabilities for this example + time.
|
||||
Step(input[t].row(b));
|
||||
} // for (int t...
|
||||
|
||||
// O(n * log(n))
|
||||
std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
|
||||
leaves_.Reset();
|
||||
for (int i = 0; i < branches->size(); ++i) {
|
||||
BeamEntry* entry = (*branches)[i];
|
||||
beam_scorer_->ExpandStateEnd(&entry->state);
|
||||
entry->newp.total +=
|
||||
beam_scorer_->GetStateEndExpansionScore(entry->state);
|
||||
leaves_.push(entry);
|
||||
}
|
||||
|
||||
TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
|
||||
|
||||
CHECK_EQ(top_n, beam_log_probabilities.size());
|
||||
CHECK_EQ(beams.size(), beam_log_probabilities.size());
|
||||
|
||||
for (int i = 0; i < top_n; ++i) {
|
||||
// Copy output to the correct beam + batch
|
||||
(*output)[i][b].swap(beams[i]);
|
||||
(*scores)(b, i) = -beam_log_probabilities[i];
|
||||
}
|
||||
} // for (int b...
|
||||
}
|
||||
|
||||
template <typename CTCBeamState, class CTCBeamScorer, typename CTCBeamComparer>
|
||||
template <typename Vector>
|
||||
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamScorer, CTCBeamComparer>::Step(
|
||||
const Vector& raw_input) {
|
||||
Eigen::ArrayXf input = raw_input;
|
||||
// Remove the max for stability when performing log-prob calculations.
|
||||
input -= input.maxCoeff();
|
||||
|
||||
// Extract the beams sorted in decreasing new probability
|
||||
CHECK_EQ(num_classes_, input.size());
|
||||
|
||||
std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
|
||||
leaves_.Reset();
|
||||
|
||||
for (BeamEntry* b : *branches) {
|
||||
// P(.. @ t) becomes the new P(.. @ t-1)
|
||||
b->oldp = b->newp;
|
||||
}
|
||||
|
||||
for (BeamEntry* b : *branches) {
|
||||
if (b->parent != nullptr) { // if not the root
|
||||
if (b->parent->Active()) {
|
||||
// If last two sequence characters are identical:
|
||||
// Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
|
||||
// + Pblank(l=ac @ t=5))
|
||||
// else:
|
||||
// Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
|
||||
// + P(l=ab @ t=5))
|
||||
float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
|
||||
: b->parent->oldp.total;
|
||||
b->newp.label =
|
||||
LogSumExp(b->newp.label,
|
||||
beam_scorer_->GetStateExpansionScore(b->state, previous));
|
||||
}
|
||||
// Plabel(l=abc @ t=6) *= P(c @ 6)
|
||||
b->newp.label += input(b->label);
|
||||
}
|
||||
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
|
||||
b->newp.blank = b->oldp.total + input(blank_index_);
|
||||
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
|
||||
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
|
||||
|
||||
// Push the entry back to the top paths list.
|
||||
// Note, this will always fill leaves back up in sorted order.
|
||||
leaves_.push(b);
|
||||
}
|
||||
|
||||
// branches must be in descending oldp order for the expansion below.
// This already holds: branches was extracted in descending newp order,
// and we copied newp to oldp above.
|
||||
|
||||
// Grow new leaves
|
||||
for (BeamEntry* b : *branches) {
|
||||
// A new leaf (represented by its BeamProbability) is a candidate
|
||||
// iff its total probability is nonzero and either the beam list
|
||||
// isn't full, or the lowest probability entry in the beam has a
|
||||
// lower probability than the leaf.
|
||||
auto is_candidate = [this](const BeamProbability& prob) {
|
||||
return (prob.total > kLogZero &&
|
||||
(leaves_.size() < beam_width_ ||
|
||||
prob.total > leaves_.peek_bottom()->newp.total));
|
||||
};
|
||||
|
||||
if (!is_candidate(b->oldp)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!b->HasChildren()) {
|
||||
b->PopulateChildren(num_classes_ - 1);
|
||||
}
|
||||
|
||||
for (BeamEntry& c : *b->Children()) {
|
||||
if (!c.Active()) {
|
||||
// Pblank(l=abcd @ t=6) = 0
|
||||
c.newp.blank = kLogZero;
|
||||
// If new child label is identical to beam label:
|
||||
// Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
|
||||
// Otherwise:
|
||||
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
|
||||
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
|
||||
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
|
||||
c.newp.label = input(c.label) +
|
||||
beam_scorer_->GetStateExpansionScore(c.state, previous);
|
||||
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
|
||||
c.newp.total = c.newp.label;
|
||||
|
||||
if (is_candidate(c.newp)) {
|
||||
BeamEntry* bottom = leaves_.peek_bottom();
|
||||
leaves_.push(&c);
|
||||
if (leaves_.size() == beam_width_) {
|
||||
// Bottom is no longer in the beam search. Reset
|
||||
// its probability; signal it's no longer in the beam search.
|
||||
bottom->newp.Reset();
|
||||
}
|
||||
} else {
|
||||
// Deactivate child (signal it's not in the beam)
|
||||
c.oldp.Reset();
|
||||
c.newp.Reset();
|
||||
}
|
||||
} // if (!c.Active()) ...
|
||||
} // for (BeamEntry& c in children...
|
||||
} // for (BeamEntry* b...
|
||||
}
|
||||
|
||||
template <typename CTCBeamState, class CTCBeamScorer, typename CTCBeamComparer>
|
||||
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamScorer,
|
||||
CTCBeamComparer>::Reset() {
|
||||
leaves_.Reset();
|
||||
|
||||
// This beam root, and all of its children, will be in memory until
|
||||
// the next reset.
|
||||
beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_ - 1, -1));
|
||||
beam_root_->newp.total = 0.0; // ln(1)
|
||||
beam_root_->newp.blank = 0.0; // ln(1)
|
||||
|
||||
// Add the root as the initial leaf.
|
||||
leaves_.push(beam_root_.get());
|
||||
|
||||
// Call initialize state on the root object.
|
||||
if (beam_scorer_) {
|
||||
beam_scorer_->InitializeState(&beam_root_->state);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CTCBeamState, class CTCBeamScorer, typename CTCBeamComparer>
|
||||
void CTCBeamSearchDecoder<CTCBeamState, CTCBeamScorer, CTCBeamComparer>::
|
||||
TopPaths(int n, std::vector<std::vector<int>>* paths,
|
||||
std::vector<float>* log_probs, bool merge_repeated) const {
|
||||
CHECK_NOTNULL(paths)->clear();
|
||||
CHECK_NOTNULL(log_probs)->clear();
|
||||
CHECK_LE(n, beam_width_) << "Requested more paths than the beam width.";
|
||||
CHECK_LE(n, leaves_.size()) << "Fewer leaves in the beam search "
<< "than requested. Have you called Step()?";
|
||||
|
||||
gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);
|
||||
|
||||
// O(beam_width_ * log(n)), space complexity is O(n)
|
||||
for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
|
||||
top_branches.push(*it);
|
||||
}
|
||||
// O(n * log(n))
|
||||
std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
BeamEntry* e((*branches)[i]);
|
||||
paths->push_back(e->LabelSeq(merge_repeated));
|
||||
log_probs->push_back(e->newp.total);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ctc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
180
tensorflow/core/util/ctc/ctc_beam_search_test.cc
Normal file
@ -0,0 +1,180 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This test illustrates how to make use of the CTCBeamSearchDecoder using a
|
||||
// custom BeamScorer and BeamState based on a dictionary with a few artificial
|
||||
// words.
|
||||
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
|
||||
|
||||
#include <cmath>
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
typedef std::vector<std::vector<std::vector<float>>> TestData;
|
||||
using tensorflow::ctc::CTCBeamSearchDecoder;
|
||||
using tensorflow::ctc::CTCDecoder;
|
||||
|
||||
// The HistoryBeamState is used to keep track of the current candidate and
|
||||
// caches the expansion score (needed by the scorer).
|
||||
struct HistoryBeamState {
|
||||
float score;
|
||||
std::vector<int> labels;
|
||||
};
|
||||
|
||||
// DictionaryBeamScorer essentially favors candidates that can still become
|
||||
// dictionary words. As soon as a beam candidate is not a dictionary word or
|
||||
// a prefix of a dictionary word it gets a low probability at each step.
|
||||
//
|
||||
// The dictionary itself is hard-coded a static const variable of the class.
|
||||
class DictionaryBeamScorer
|
||||
: public tensorflow::ctc::BeamScorerInterface<HistoryBeamState> {
|
||||
public:
|
||||
void InitializeState(HistoryBeamState* root) const override {
|
||||
root->score = 0;
|
||||
}
|
||||
|
||||
void ExpandState(const HistoryBeamState& from_state, int from_label,
|
||||
HistoryBeamState* to_state, int to_label) const override {
|
||||
// Keep track of the current complete candidate by storing the labels along
|
||||
// the expansion path in the beam state.
|
||||
to_state->labels.push_back(to_label);
|
||||
SetStateScoreAccordingToDict(to_state);
|
||||
}
|
||||
|
||||
void ExpandStateEnd(HistoryBeamState* state) const override {
|
||||
SetStateScoreAccordingToDict(state);
|
||||
}
|
||||
|
||||
float GetStateExpansionScore(const HistoryBeamState& state,
|
||||
float previous_score) const override {
|
||||
return previous_score + state.score;
|
||||
}
|
||||
|
||||
float GetStateEndExpansionScore(
|
||||
const HistoryBeamState& state) const override {
|
||||
return state.score;
|
||||
}
|
||||
|
||||
// Simple dictionary used when scoring the beams to check if they are prefixes
|
||||
// of dictionary words (see SetStateScoreAccordingToDict below).
|
||||
static const std::vector<std::vector<int>> dictionary_;
|
||||
|
||||
private:
|
||||
void SetStateScoreAccordingToDict(HistoryBeamState* state) const;
|
||||
};
|
||||
|
||||
const std::vector<std::vector<int>> DictionaryBeamScorer::dictionary_ = {
|
||||
{3}, {3, 1}};
|
||||
|
||||
void DictionaryBeamScorer::SetStateScoreAccordingToDict(
|
||||
HistoryBeamState* state) const {
|
||||
// Check if the beam can still be a dictionary word (e.g. prefix of one).
|
||||
const std::vector<int>& candidate = state->labels;
|
||||
for (int w = 0; w < dictionary_.size(); ++w) {
|
||||
const std::vector<int>& word = dictionary_[w];
|
||||
// If the length of the current beam is already larger, skip.
|
||||
if (candidate.size() > word.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(word.begin(), word.begin() + candidate.size(),
|
||||
candidate.begin())) {
|
||||
state->score = std::log(1.0);
|
||||
return;
|
||||
}
|
||||
}
|
||||
// At this point, the candidate certainly can't be in the dictionary.
|
||||
state->score = std::log(0.01);
|
||||
}
|
||||
|
||||
TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
|
||||
const int batch_size = 1;
|
||||
const int timesteps = 5;
|
||||
const int top_paths = 3;
|
||||
const int num_classes = 6;
|
||||
|
||||
// Plain decoder using hibernating beam search algorithm.
|
||||
CTCBeamSearchDecoder<> decoder(num_classes, 10 * top_paths, batch_size,
|
||||
false);
|
||||
|
||||
// Dictionary decoder, allowing only two dictionary words : {3}, {3, 1}.
|
||||
CTCBeamSearchDecoder<HistoryBeamState, DictionaryBeamScorer>
|
||||
dictionary_decoder(num_classes, top_paths, batch_size, false);
|
||||
|
||||
// Raw data containers (arrays of floats, ints, etc.).
|
||||
int sequence_lengths[batch_size] = {timesteps};
|
||||
float input_data_mat[timesteps][batch_size][num_classes] = {
|
||||
{{0, 0.6, 0, 0.4, 0, 0}},
|
||||
{{0, 0.5, 0, 0.5, 0, 0}},
|
||||
{{0, 0.4, 0, 0.6, 0, 0}},
|
||||
{{0, 0.4, 0, 0.6, 0, 0}},
|
||||
{{0, 0.4, 0, 0.6, 0, 0}}};
|
||||
|
||||
// The CTCDecoder works with log-probs.
|
||||
for (int t = 0; t < timesteps; ++t) {
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
for (int c = 0; c < num_classes; ++c) {
|
||||
input_data_mat[t][b][c] = std::log(input_data_mat[t][b][c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Plain output, without any additional scoring.
|
||||
std::vector<CTCDecoder::Output> expected_output = {
|
||||
{{1, 3}, {1, 3, 1}, {3, 1, 3}},
|
||||
};
|
||||
|
||||
// Dictionary outputs: preference for dictionary candidates. The
|
||||
// second candidate is there, despite it not being a dictionary word, due to
|
||||
// stronger probability in the input to the decoder.
|
||||
std::vector<CTCDecoder::Output> expected_dict_output = {
|
||||
{{3}, {1, 3}, {3, 1}},
|
||||
};
|
||||
|
||||
// Convert data containers to the format accepted by the decoder, simply
// mapping the memory from the containers to an Eigen::ArrayXi or
// Eigen::MatrixXf using Eigen::Map.
|
||||
Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
|
||||
std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
|
||||
for (int t = 0; t < timesteps; ++t) {
|
||||
inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
|
||||
}
|
||||
|
||||
// Prepare containers for output and scores.
|
||||
std::vector<CTCDecoder::Output> outputs(top_paths);
|
||||
for (CTCDecoder::Output& output : outputs) {
|
||||
output.resize(batch_size);
|
||||
}
|
||||
float score[batch_size][top_paths] = {{0.0}};
|
||||
Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
|
||||
|
||||
decoder.Decode(seq_len, inputs, &outputs, &scores);
|
||||
for (int path = 0; path < top_paths; ++path) {
|
||||
EXPECT_EQ(outputs[path][0], expected_output[0][path]);
|
||||
}
|
||||
|
||||
// Prepare dictionary outputs.
|
||||
std::vector<CTCDecoder::Output> dict_outputs(top_paths);
|
||||
for (CTCDecoder::Output& output : dict_outputs) {
|
||||
output.resize(batch_size);
|
||||
}
|
||||
dictionary_decoder.Decode(seq_len, inputs, &dict_outputs, &scores);
|
||||
for (int path = 0; path < top_paths; ++path) {
|
||||
EXPECT_EQ(dict_outputs[path][0], expected_dict_output[0][path]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
102
tensorflow/core/util/ctc/ctc_decoder.h
Normal file
@ -0,0 +1,102 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ctc {
|
||||
|
||||
// The CTCDecoder is an abstract interface to be implemented when providing a
|
||||
// decoding method on the timestep output of a RNN trained with CTC loss.
|
||||
//
|
||||
// The two types of decoding available are:
|
||||
// - greedy path, through the CTCGreedyDecoder
|
||||
// - beam search, through the CTCBeamSearchDecoder
|
||||
class CTCDecoder {
|
||||
public:
|
||||
typedef Eigen::Map<const Eigen::ArrayXi> SequenceLength;
|
||||
typedef Eigen::Map<const Eigen::MatrixXf> Input;
|
||||
typedef std::vector<std::vector<int>> Output;
|
||||
typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;
|
||||
|
||||
CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
|
||||
: num_classes_(num_classes),
|
||||
blank_index_(num_classes - 1),
|
||||
batch_size_(batch_size),
|
||||
merge_repeated_(merge_repeated) {}
|
||||
|
||||
virtual ~CTCDecoder() {}
|
||||
|
||||
// Dimensionality of the input/output is expected to be:
|
||||
// - seq_len[b] - b = 0 to batch_size_
|
||||
// - input[t].row(b) - t = 0 to timesteps; b = 0 to batch_size_
|
||||
// - output.size() specifies the number of beams to be returned.
|
||||
// - scores(b, i) - b = 0 to batch_size; i = 0 to output.size()
|
||||
virtual void Decode(const SequenceLength& seq_len,
|
||||
const std::vector<Input>& input,
|
||||
std::vector<Output>* output, ScoreOutput* scores) = 0;
|
||||
|
||||
int batch_size() { return batch_size_; }
|
||||
int num_classes() { return num_classes_; }
|
||||
|
||||
protected:
|
||||
int num_classes_;
|
||||
int blank_index_;
|
||||
int batch_size_;
|
||||
bool merge_repeated_;
|
||||
};
|
||||
|
||||
// CTCGreedyDecoder is an implementation of the simple best path decoding
// algorithm, selecting the most likely class at each timestep.
|
||||
class CTCGreedyDecoder : public CTCDecoder {
|
||||
public:
|
||||
CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
|
||||
: CTCDecoder(num_classes, batch_size, merge_repeated) {}
|
||||
|
||||
void Decode(const CTCDecoder::SequenceLength& seq_len,
|
||||
const std::vector<CTCDecoder::Input>& input,
|
||||
std::vector<CTCDecoder::Output>* output,
|
||||
CTCDecoder::ScoreOutput* scores) override {
|
||||
// For each batch entry, identify the transitions
|
||||
for (int b = 0; b < batch_size_; ++b) {
|
||||
int seq_len_b = seq_len[b];
|
||||
// Only writing to beam 0
|
||||
std::vector<int>& output_b = (*output)[0][b];
|
||||
|
||||
int prev_class_ix = -1;
|
||||
std::vector<int> transcription;
|
||||
(*scores)(b, 0) = 0;
|
||||
for (int t = 0; t < seq_len_b; ++t) {
|
||||
auto row = input[t].row(b);
|
||||
int max_class_ix;
|
||||
(*scores)(b, 0) += -row.maxCoeff(&max_class_ix);
|
||||
if (max_class_ix != blank_index_ &&
|
||||
!(merge_repeated_ && max_class_ix == prev_class_ix)) {
|
||||
output_b.push_back(max_class_ix);
|
||||
transcription.push_back(max_class_ix);
|
||||
}
|
||||
prev_class_ix = max_class_ix;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ctc
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
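A hedged usage sketch for CTCGreedyDecoder (not part of the commit; it mirrors
the Eigen::Map pattern used by the beam search test above, with made-up sizes
and log-probabilities):

#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/util/ctc/ctc_decoder.h"

int main() {
  const int batch_size = 1, timesteps = 3, num_classes = 3;  // made up
  // Per-timestep log-probabilities, one batch_size x num_classes block each.
  float data[timesteps][batch_size][num_classes] = {
      {{-0.1f, -2.0f, -3.0f}},   // argmax: class 0
      {{-0.1f, -2.0f, -3.0f}},   // argmax: class 0 (merged below)
      {{-2.0f, -0.1f, -3.0f}}};  // argmax: class 1
  int seq_len_data[batch_size] = {timesteps};

  Eigen::Map<const Eigen::ArrayXi> seq_len(&seq_len_data[0], batch_size);
  std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
  for (int t = 0; t < timesteps; ++t) {
    inputs.emplace_back(&data[t][0][0], batch_size, num_classes);
  }

  // One output beam; greedy decoding only ever writes beam 0.
  std::vector<tensorflow::ctc::CTCDecoder::Output> outputs(1);
  outputs[0].resize(batch_size);
  float score_data[batch_size][1] = {{0.0f}};
  Eigen::Map<Eigen::MatrixXf> scores(&score_data[0][0], batch_size, 1);

  tensorflow::ctc::CTCGreedyDecoder decoder(num_classes, batch_size,
                                            /*merge_repeated=*/true);
  decoder.Decode(seq_len, inputs, &outputs, &scores);
  // outputs[0][0] should now be {0, 1}: the repeat was merged and the blank
  // (num_classes - 1 == 2) never appears as an argmax.
  return 0;
}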
|
185
tensorflow/core/util/ctc/ctc_loss_calculator.cc
Normal file
@ -0,0 +1,185 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ctc {
|
||||
|
||||
// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
|
||||
// Starting with t = 0 instead of t = 1 used in the text.
|
||||
// Based on Kanishka's CTC.
|
||||
void CTCLossCalculator::CalculateForwardVariables(
|
||||
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
|
||||
Matrix* log_alpha) const {
|
||||
// Number of cols is the number of time steps = number of cols in target
|
||||
// after the output delay.
|
||||
log_alpha->setConstant(kLogZero);
|
||||
|
||||
int U = l_prime.size();
|
||||
int T = log_alpha->cols();
|
||||
|
||||
CHECK_EQ(U, log_alpha->rows());
|
||||
|
||||
// Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
|
||||
log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
|
||||
// Below, l_prime[1] == labels[0]
|
||||
auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
|
||||
log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));
|
||||
|
||||
for (int t = 1; t < T; ++t) {
|
||||
// If there is not enough time to output the remaining labels or
|
||||
// some labels have been skipped, then let log_alpha(u, t) continue to
|
||||
// be kLogZero.
|
||||
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
|
||||
++u) {
|
||||
// Begin (GravesTh) Eq 7.9
|
||||
// Add in the u, t - 1 term.
|
||||
float sum_log_alpha = kLogZero;
|
||||
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
|
||||
sum_log_alpha = log_alpha->coeff(u, t - 1);
|
||||
}
|
||||
|
||||
// Add in the u - 1, t - 1 term.
|
||||
if (u > 0) {
|
||||
sum_log_alpha =
|
||||
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
|
||||
}
|
||||
|
||||
// Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
|
||||
if (u > 1) {
|
||||
const bool matching_labels_merge =
|
||||
ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
|
||||
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
|
||||
sum_log_alpha =
|
||||
LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
|
||||
}
|
||||
}
|
||||
// Multiply the summed alphas with the activation log probability.
|
||||
log_alpha->coeffRef(u, t) =
|
||||
log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
|
||||
} // End (GravesTh) Eq 7.9.
|
||||
}
|
||||
}
|
||||
|
||||
// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
|
||||
void CTCLossCalculator::CalculateBackwardVariables(
|
||||
const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
|
||||
Matrix* log_beta) const {
|
||||
// Number of cols is the number of time steps = number of cols in target.
|
||||
// Matrix log_beta =
|
||||
// Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
|
||||
// kLogZero);
|
||||
log_beta->setConstant(kLogZero);
|
||||
int T = log_beta->cols();
|
||||
int U = l_prime.size();
|
||||
CHECK_EQ(U, log_beta->rows());
|
||||
|
||||
// Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
|
||||
for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
|
||||
|
||||
for (int t = T - 1 - 1; t >= 0; --t) {
|
||||
// If there is not enough time to output the remaining labels or
|
||||
// some labels have been skipped, then let log_beta(u, t) continue to
|
||||
// be kLogZero.
|
||||
for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
|
||||
++u) {
|
||||
// Begin (GravesTh) Eq 7.15
|
||||
// Add in the u, t + 1 term.
|
||||
if (ctc_merge_repeated || l_prime[u] == blank_index_) {
|
||||
log_beta->coeffRef(u, t) =
|
||||
LogSumExp(log_beta->coeff(u, t),
|
||||
log_beta->coeff(u, t + 1) +
|
||||
log(y(l_prime[u], output_delay_ + t + 1)));
|
||||
}
|
||||
|
||||
// Add in the u + 1, t + 1 term.
|
||||
if (u + 1 < U) {
|
||||
log_beta->coeffRef(u, t) =
|
||||
LogSumExp(log_beta->coeff(u, t),
|
||||
log_beta->coeff(u + 1, t + 1) +
|
||||
log(y(l_prime[u + 1], output_delay_ + t + 1)));
|
||||
}
|
||||
|
||||
// Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
|
||||
if (u + 2 < U) {
|
||||
const bool matching_labels_merge =
|
||||
ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
|
||||
if (l_prime[u] != blank_index_ && !matching_labels_merge) {
|
||||
// Add in u + 2 term.
|
||||
log_beta->coeffRef(u, t) =
|
||||
LogSumExp(log_beta->coeff(u, t),
|
||||
log_beta->coeff(u + 2, t + 1) +
|
||||
log(y(l_prime[u + 2], output_delay_ + t + 1)));
|
||||
}
|
||||
} // End (GravesTh) Eq. 7.15
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Using (GravesTh) Eq 7.26 & 7.34.
|
||||
void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
|
||||
const Matrix& y,
|
||||
const Matrix& log_alpha,
|
||||
const Matrix& log_beta,
|
||||
float log_p_z_x, Matrix* dy) const {
|
||||
// Only working with the leftmost part of dy for this batch element.
|
||||
auto dy_b = dy->leftCols(y.cols());
|
||||
|
||||
// It is possible that no valid path is found if the activations for the
|
||||
// targets are zero.
|
||||
if (log_p_z_x == kLogZero) {
|
||||
LOG(WARNING) << "No valid path found.";
|
||||
dy_b = y;
|
||||
return;
|
||||
}
|
||||
|
||||
int L = y.rows();
|
||||
int T = y.cols();
|
||||
int U = l_prime.size();
|
||||
|
||||
for (int t = 0; t < T - output_delay_; ++t) {
|
||||
Array prob_sum(L);
|
||||
prob_sum.setConstant(kLogZero);
|
||||
|
||||
for (int u = 0; u < U; ++u) {
|
||||
int l = l_prime[u];
|
||||
prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
|
||||
}
|
||||
|
||||
for (int l = 0; l < L; ++l) {
|
||||
// Negative term in (GravesTh) Eq 7.28.
|
||||
float negative_term = expf(prob_sum[l] - log_p_z_x);
|
||||
|
||||
dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
|
||||
std::vector<int>* l_prime) const {
|
||||
// Assumption is that l_prime is empty.
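// Example: for l = {1, 2} with '-' denoting the blank index, the loop below
// produces l' = {-, 1, -, 2, -}.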
|
||||
l_prime->reserve(2 * l.size() + 1);
|
||||
|
||||
for (auto label : l) {
|
||||
l_prime->push_back(blank_index_);
|
||||
l_prime->push_back(label);
|
||||
}
|
||||
// Add final blank to l'.
|
||||
l_prime->push_back(blank_index_);
|
||||
}
|
||||
|
||||
} // namespace ctc
|
||||
} // namespace tensorflow
|
318
tensorflow/core/util/ctc/ctc_loss_calculator.h
Normal file
@ -0,0 +1,318 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/util/ctc/ctc_loss_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ctc {
|
||||
|
||||
using strings::StrCat;
|
||||
|
||||
class CTCLossCalculator {
|
||||
// Connectionist Temporal Classification Loss
|
||||
//
|
||||
// Implementation by kanishkarao@, posenhuang@, and ebrevdo@.
|
||||
//
|
||||
// The CTC Loss layer learns a *transition* probability value for each
|
||||
// input time step. The transitions are on the class alphabet
|
||||
// {0, 1, ..., N-2}
|
||||
// where N is the depth of the input layer (the size of the alphabet is N-1).
|
||||
// Note: The token N-1 is reserved for the "no transition" output, so
|
||||
// make sure that your input layer has a depth that's one larger than
|
||||
// the set of classes you're training on. Also make sure that your
|
||||
// training labels do not have a class value of N-1, as training will skip
|
||||
// these examples.
|
||||
//
|
||||
// Reference materials:
|
||||
// GravesTh: Alex Graves, "Supervised Sequence Labelling with Recurrent
|
||||
// Neural Networks" (PhD Thesis), Technische Universit¨at M¨unchen.
|
||||
public:
|
||||
typedef std::vector<std::vector<int>> LabelSequences;
|
||||
typedef Eigen::MatrixXf Matrix;
|
||||
typedef Eigen::ArrayXf Array;
|
||||
typedef Eigen::Map<const Eigen::MatrixXf> InputMap;
|
||||
typedef Eigen::Map<Eigen::MatrixXf> OutputMap;
|
||||
|
||||
CTCLossCalculator(int blank_index, int output_delay)
|
||||
: blank_index_(blank_index), output_delay_(output_delay) {}
|
||||
|
||||
template <typename VectorIn, typename VectorOut, typename MatrixIn,
|
||||
typename MatrixOut>
|
||||
Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels,
|
||||
const std::vector<MatrixIn>& inputs,
|
||||
bool preprocess_collapse_repeated,
|
||||
bool ctc_merge_repeated, VectorOut* loss,
|
||||
std::vector<MatrixOut>* gradients) const;
|
||||
|
||||
private:
|
||||
void CalculateForwardVariables(const std::vector<int>& l_prime,
|
||||
const Matrix& y, bool ctc_merge_repeated,
|
||||
Matrix* log_alpha) const;
|
||||
|
||||
void CalculateBackwardVariables(const std::vector<int>& l_prime,
|
||||
const Matrix& y, bool ctc_merge_repeated,
|
||||
Matrix* log_beta) const;
|
||||
|
||||
void CalculateGradient(const std::vector<int>& l_prime, const Matrix& y,
|
||||
const Matrix& log_alpha, const Matrix& log_beta,
|
||||
float log_p_z_x, Matrix* dy) const;
|
||||
|
||||
void GetLPrimeIndices(const std::vector<int>& l,
|
||||
std::vector<int>* l_prime) const;
|
||||
|
||||
// Helper function that calculates the l_prime indices for all
|
||||
// batches at the same time, and identifies errors for any given
|
||||
// batch. Return value:
|
||||
// max_{b in batch_size} l_primes[b].size()
|
||||
template <typename Vector>
|
||||
Status PopulateLPrimes(bool preprocess_collapse_repeated, int batch_size,
|
||||
int num_classes, const Vector& seq_len,
|
||||
const LabelSequences& labels, size_t* max_u_prime,
|
||||
LabelSequences* l_primes) const;
|
||||
|
||||
// Utility indices for the CTC algorithm.
|
||||
int blank_index_;
|
||||
|
||||
// Delay for target labels in time steps (the delay before the output
// sequence starts).
|
||||
const int output_delay_;
|
||||
};
|
||||
|
||||
template <typename VectorIn, typename VectorOut, typename MatrixIn,
          typename MatrixOut>
Status CTCLossCalculator::CalculateLoss(
    const VectorIn& seq_len, const LabelSequences& labels,
    const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
    bool ctc_merge_repeated, VectorOut* loss,
    std::vector<MatrixOut>* gradients) const {
  auto num_time_steps = inputs.size();

  if (loss == nullptr) {
    return errors::InvalidArgument("loss == nullptr");
  }

  bool requires_backprop = (gradients != nullptr);

  auto batch_size = inputs[0].rows();
  auto num_classes = inputs[0].cols();

  if (loss->size() != batch_size) {
    return errors::InvalidArgument("loss.size() != batch_size");
  }
  loss->setZero();

  for (int t = 1; t < num_time_steps; ++t) {
    if (inputs[t].rows() != batch_size) {
      return errors::InvalidArgument("Expected batch size at t: ", t,
                                     " to be: ", batch_size, " but got: ",
                                     inputs[t].rows());
    }
    if (inputs[t].cols() != num_classes) {
      return errors::InvalidArgument("Expected class count at t: ", t,
                                     " to be: ", num_classes, " but got: ",
                                     inputs[t].cols());
    }
  }

  // Check validity of sequence_length array values.
  for (int b = 0; b < batch_size; b++) {
    if (seq_len(b) < 0) {
      return errors::InvalidArgument("seq_len(", b, ") < 0");
    }
    if (seq_len(b) > num_time_steps) {
      return errors::InvalidArgument("seq_len(", b, ") > num_time_steps");
    }
  }

  // Calculate the modified label sequence l' for each batch element,
  // and calculate the maximum necessary allocation size.
  LabelSequences l_primes(batch_size);
  size_t max_u_prime = 0;
  Status l_p_ret =
      PopulateLPrimes(preprocess_collapse_repeated, batch_size, num_classes,
                      seq_len, labels, &max_u_prime, &l_primes);
  if (!l_p_ret.ok()) {
    return l_p_ret;
  }

  // For each batch element, log(alpha) and log(beta).  Here we provide
  // enough storage for the maximum possible size:
  //   row size: u_prime == l_prime.size()
  //   col size: seq_len[b] - output_delay_
  Matrix log_alpha(max_u_prime, num_time_steps - output_delay_);
  Matrix log_beta(max_u_prime, num_time_steps - output_delay_);

  // Work matrices, pre-allocated to their maximum sizes.
  Matrix y(num_classes, num_time_steps);
  Matrix dy;
  if (requires_backprop) dy = Matrix::Zero(y.rows(), y.cols());

  // CTC is calculated one batch element at a time.
  for (int b = 0; b < batch_size; b++) {
    if (seq_len(b) == 0) {
      continue;
    }

    // For this batch element, work only with the first seq_len(b) columns.
    Matrix y_b = y.leftCols(seq_len(b));

    const std::vector<int>& l_prime = l_primes[b];

    // For this batch element, work only with log_alpha, log_beta blocks of
    // the necessary size.
    Matrix log_alpha_b =
        log_alpha.topLeftCorner(l_prime.size(), seq_len(b) - output_delay_);
    Matrix log_beta_b =
        log_beta.topLeftCorner(l_prime.size(), seq_len(b) - output_delay_);

    // y holds the output activations, num_classes x num_time_steps.
    Eigen::ArrayXf y_b_col;
    for (int t = 0; t < seq_len(b); t++) {
      // Calculate the softmax of y_b.  Subtracting the row maximum before
      // exponentiating keeps exp() from overflowing; the result is
      // unchanged because softmax(v) == softmax(v - max(v)).
      float max_coeff = inputs[t].row(b).maxCoeff();
      y_b_col = (inputs[t].row(b).array() - max_coeff).exp();
      y_b.col(t) = y_b_col / y_b_col.sum();
    }

    // Compute the forward and backward variables.
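    // log_alpha_b(u, t) accumulates the log-probability of the prefixes of
    // l_prime ending at position u at time t; log_beta_b(u, t) covers the
    // matching suffixes (the standard CTC forward-backward recursions,
    // GravesTh ch. 7).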
    CalculateForwardVariables(l_prime, y_b, ctc_merge_repeated, &log_alpha_b);
    CalculateBackwardVariables(l_prime, y_b, ctc_merge_repeated, &log_beta_b);

    // The loss is the negative log-probability log(p(z|x)) of the target
    // sequence given the prediction.  Evaluate log_prob lazily here.
    float log_p_z_x = kLogZero;
    for (int u = 0; u < l_prime.size(); ++u) {
      // (GravesTh) Eq 7.26, sum over all paths for t = 0.
      log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0));
    }
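
    // In equation form (GravesTh Eq. 7.26), for any fixed t:
    //   p(z|x) = sum_u alpha_t(u) * beta_t(u),
    // which the loop above evaluates at t = 0 in log space.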
    (*loss)(b) = -log_p_z_x;  // Use negative log loss for display.

    // Compute the gradient with respect to the input activations, if needed.
    if (requires_backprop) {
      dy.setZero();
      CalculateGradient(l_prime, y_b, log_alpha_b, log_beta_b, log_p_z_x, &dy);

      // Copy the gradient for the current sample into the per-time-step
      // output matrices.
      for (int t = 0; t < seq_len(b); t++) {
        (*gradients)[t].row(b).array() = dy.col(t);
      }
    }
  }  // for (int b = ...

  return Status::OK();
}

template <typename Vector>
Status CTCLossCalculator::PopulateLPrimes(bool preprocess_collapse_repeated,
                                          int batch_size, int num_classes,
                                          const Vector& seq_len,
                                          const LabelSequences& labels,
                                          size_t* max_u_prime,
                                          LabelSequences* l_primes) const {
  // labels is an array of label sequences, one per batch element.
  if (labels.size() != batch_size) {
    return errors::InvalidArgument("labels.size() != batch_size: ",
                                   labels.size(), " vs. ", batch_size);
  }

  *max_u_prime = 0;  // Keep track of the longest l' modified label sequence.
  for (int b = 0; b < batch_size; b++) {
    const std::vector<int>& label = labels[b];
    if (label.size() == 0) {
      return errors::InvalidArgument("Labels length is zero in batch ", b);
    }

    // If debugging: output the labels coming into training.
    VLOG(2) << "label for batch: " << b << ": " << str_util::Join(label, " ");

    // Target indices, length = U.
    std::vector<int> l;

    // Collapse repeated labels if requested, and treat any label with
    // index >= num_classes - 1 as a null (end-of-sequence) marker; a
    // non-null label may not follow a null one.
    bool finished_sequence = false;
    for (int i = 0; i < label.size(); ++i) {
      if (i == 0 || !preprocess_collapse_repeated || label[i] != label[i - 1]) {
        if (label[i] >= num_classes - 1) {
          finished_sequence = true;
        } else {
          if (finished_sequence) {
            // A non-null label followed a null one: invalid sequence.
            return errors::InvalidArgument(
                "Saw a non-null label following a null label (index >= "
                "num_classes - 1), batch: ", b, " num_classes: ", num_classes,
                " labels: ", str_util::Join(l, ","));
          }
          l.push_back(label[i]);
        }
      }
    }

    // Make sure there is enough time to output the target indices.
    int time = seq_len(b) - output_delay_;
    int required_time = label.size();
    for (int l_i : l) {
      if (l_i < 0) {
        return errors::InvalidArgument(
            "All labels must be nonnegative integers, batch: ", b,
            " labels: ", str_util::Join(l, ","));
      } else if (l_i >= num_classes) {
        return errors::InvalidArgument(
            "No label may be greater than or equal to num_classes. ",
            "num_classes: ", num_classes, ", batch: ", b,
            " labels: ", str_util::Join(l, ","));
      }
    }
    if (required_time > time) {
      return errors::InvalidArgument(
          "Not enough time for target transition sequence (required: ",
          required_time, ", available: ", time,
          "), skipping data instance in batch: ", b);
    }

    // Target indices with blanks before each index and a blank at the end.
    // Length U' = 2U + 1.
    // Convert l to l_prime.
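    // For example, with blank index B, l = {1, 2, 3} expands to
    // l_prime = {B, 1, B, 2, B, 3, B}, so U = 3 gives U' = 7.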
    GetLPrimeIndices(l, &l_primes->at(b));
    *max_u_prime = std::max(*max_u_prime, l_primes->at(b).size());
  }
  return Status::OK();
}

}  // namespace ctc
}  // namespace tensorflow

#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_
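
For orientation, here is a minimal sketch of how CalculateLoss might be driven from C++, assuming Eigen dense types for the template parameters. The function name ExampleCtcLoss, the shapes, and the random inputs are illustrative assumptions, not code from this commit:

// Illustrative sketch only -- not part of this commit.
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"

tensorflow::Status ExampleCtcLoss() {
  using tensorflow::ctc::CTCLossCalculator;
  const int batch_size = 1;
  const int num_time_steps = 8;
  const int num_classes = 5;  // Four real classes plus the reserved blank (4).

  // One (batch_size x num_classes) activation matrix per time step.
  std::vector<Eigen::MatrixXf> inputs(
      num_time_steps, Eigen::MatrixXf::Random(batch_size, num_classes));
  std::vector<Eigen::MatrixXf> gradients(
      num_time_steps, Eigen::MatrixXf::Zero(batch_size, num_classes));

  Eigen::VectorXi seq_len(batch_size);
  seq_len << num_time_steps;
  CTCLossCalculator::LabelSequences labels = {{0, 1, 2}};  // U = 3, U' = 7.
  Eigen::VectorXf loss(batch_size);

  CTCLossCalculator calculator(num_classes - 1 /* blank_index */,
                               0 /* output_delay */);
  return calculator.CalculateLoss(seq_len, labels, inputs,
                                  false /* preprocess_collapse_repeated */,
                                  true /* ctc_merge_repeated */, &loss,
                                  &gradients);
}

The InputMap/OutputMap typedefs above suggest the real kernel passes Eigen::Map views over the op's tensors; dense matrices are used here only to keep the sketch self-contained.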
46
tensorflow/core/util/ctc/ctc_loss_util.h
Normal file
@ -0,0 +1,46 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_

#include <cmath>
#include <limits>

namespace tensorflow {
namespace ctc {

const float kLogZero = -std::numeric_limits<float>::infinity();

// Add logarithmic probabilities using:
//   ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
// The two inputs are assumed to be log probabilities.
// (GravesTh) Eq. 7.18.
inline float LogSumExp(float log_prob_1, float log_prob_2) {
  // Always have 'b' be the smaller number to keep the exponential from
  // blowing up.
  if (log_prob_1 == kLogZero && log_prob_2 == kLogZero) {
    return kLogZero;
  } else {
    return (log_prob_1 > log_prob_2)
               ? log_prob_1 + log1pf(expf(log_prob_2 - log_prob_1))
               : log_prob_2 + log1pf(expf(log_prob_1 - log_prob_2));
  }
}
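
// Sanity check (illustrative): LogSumExp(logf(0.5f), logf(0.25f)) returns
// logf(0.75f) up to float rounding, and LogSumExp(kLogZero, x) returns x,
// since log1pf(expf(kLogZero - x)) == log1pf(0.0f) == 0.0f.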

}  // namespace ctc
}  // namespace tensorflow

#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
@ -483,6 +483,16 @@ tf_gen_op_wrapper_py(
    ],
)

tf_gen_op_wrapper_py(
    name = "ctc_ops",
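    # "hidden" keeps these generated wrappers out of the public API (they
    # land in ops/gen_ctc_ops.py); the hand-written wrappers in
    # contrib/ctc/ctc_ops.py re-export the functionality instead.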
    hidden = [
        "CTCLoss",
        "CTCGreedyDecoder",
        "CTCBeamSearchDecoder",
    ],
    require_shape_functions = True,
)

tf_gen_op_wrapper_py(
    name = "data_flow_ops",
    hidden = [
@ -728,6 +738,7 @@ py_library(
        "ops/gen_array_ops.py",
        "ops/gen_attention_ops.py",
        "ops/gen_control_flow_ops.py",
        "ops/gen_ctc_ops.py",
        "ops/gen_data_flow_ops.py",
        "ops/gen_image_ops.py",
        "ops/gen_io_ops.py",