Move CTC out of contrib and document.
Change: 125022295
This commit is contained in:
parent
19376f7010
commit
67d3c915a8
RELEASE.md
tensorflow
@ -1,6 +1,8 @@
|
||||
# Changes Since Last Release
|
||||
|
||||
## Features & Improvements
|
||||
* Connectionist Temporal Classification ops are now "official" (see, e.g.,
|
||||
`tf.nn.ctc_loss`)
|
||||
* The RNN api is finally "official" (see, e.g., `tf.nn.dynamic_rnn`,
|
||||
`tf.nn.rnn`, and the classes in `tf.nn.rnn_cell`).
|
||||
* TensorBoard now has an Audio Dashboard, with associated audio summaries.
|
||||
|
@ -66,7 +66,6 @@ filegroup(
|
||||
"//tensorflow/cc:all_files",
|
||||
"//tensorflow/contrib:all_files",
|
||||
"//tensorflow/contrib/copy_graph:all_files",
|
||||
"//tensorflow/contrib/ctc:all_files",
|
||||
"//tensorflow/contrib/distributions:all_files",
|
||||
"//tensorflow/contrib/ffmpeg:all_files",
|
||||
"//tensorflow/contrib/ffmpeg/default:all_files",
|
||||
|
@ -15,7 +15,6 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/contrib/bayesflow:bayesflow_py",
|
||||
"//tensorflow/contrib/copy_graph:copy_graph_py",
|
||||
"//tensorflow/contrib/ctc:ctc_py",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
# Add projects here, they will show up under tf.contrib.
|
||||
from tensorflow.contrib import bayesflow
|
||||
from tensorflow.contrib import copy_graph
|
||||
from tensorflow.contrib import ctc
|
||||
from tensorflow.contrib import distributions
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import grid_rnn
|
||||
|
@ -1,55 +0,0 @@
|
||||
# Description:
|
||||
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
|
||||
py_library(
|
||||
name = "ctc_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"ctc_ops.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "ctc_decoder_ops_test",
|
||||
size = "small",
|
||||
srcs = ["ctc_decoder_ops_test.py"],
|
||||
additional_deps = [
|
||||
":ctc_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ctc_loss_op_test",
|
||||
size = "small",
|
||||
srcs = ["ctc_loss_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ctc_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -1,29 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Ops for CTC (Connectionist Temporal Classification).
|
||||
|
||||
@@ctc_loss
|
||||
@@ctc_greedy_decoder
|
||||
@@ctc_beam_search_decoder
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.ctc.ctc_ops import *
|
@ -0,0 +1,71 @@
|
||||
### `tf.nn.ctc_loss(inputs, labels, sequence_length, preprocess_collapse_repeated=False, ctc_merge_repeated=True)` {#ctc_loss}
|
||||
|
||||
Computes the CTC (Connectionist Temporal Classification) Loss.
|
||||
|
||||
See the article:
|
||||
|
||||
A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
|
||||
Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
|
||||
with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, pp. 369-376.
|
||||
|
||||
Input requirements:
|
||||
|
||||
```
|
||||
sequence_length(b) <= time for all b
|
||||
|
||||
max(labels.indices(labels.indices[:, 1] == b, 2))
|
||||
<= sequence_length(b) for all b.
|
||||
```
|
||||
|
||||
Regarding the arguments `preprocess_collapse_repeated` and
|
||||
`ctc_merge_repeated`:
|
||||
|
||||
If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
|
||||
repeated non-blank labels will not be merged and are interpreted
|
||||
as individual labels. This is a simplified (non-standard) version of CTC.
|
||||
|
||||
Here is a table of the (roughly) expected first order behavior:
|
||||
|
||||
* `preprocess_collapse_repeated=False, ctc_merge_repeated=True`
|
||||
|
||||
Classical CTC behavior: Outputs true repeated classes with nulls in
|
||||
between, and can also output repeated classes with no nulls in
|
||||
between that need to be collapsed by the decoder.
|
||||
|
||||
* `preprocess_collapse_repeated=True, ctc_merge_repeated=False`
|
||||
|
||||
Never learns to output repeated classes, as they are collapsed in the input labels before training.
|
||||
|
||||
* `preprocess_collapse_repeated=False, ctc_merge_repeated=False`
|
||||
|
||||
Outputs repeated classes with nulls in between, but generally does not
|
||||
require the decoder to collapse/merge repeated classes.
|
||||
|
||||
* `preprocess_collapse_repeated=True, ctc_merge_repeated=True`
|
||||
|
||||
Untested. Very likely will not learn to output repeated classes.
|
||||
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`inputs`</b>: 3-D `float` `Tensor` sized
|
||||
`[max_time x batch_size x num_classes]`. The logits.
|
||||
* <b>`labels`</b>: An `int32` `SparseTensor`.
|
||||
`labels.indices[i, :] == [b, t]` means `labels.values[i]` stores
|
||||
the id for (batch b, time t). See `core/ops/ctc_ops.cc` for more details.
|
||||
* <b>`sequence_length`</b>: 1-D `int32` vector, size `[batch_size]`.
|
||||
The sequence lengths.
|
||||
* <b>`preprocess_collapse_repeated`</b>: Boolean. Default: False.
|
||||
If True, repeated labels are collapsed prior to the CTC calculation.
|
||||
* <b>`ctc_merge_repeated`</b>: Boolean. Default: True.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A 1-D `float` `Tensor`, size `[batch]`, containing the negative log probabilities.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if labels is not a `SparseTensor`.
|
||||
|
@ -0,0 +1,36 @@
|
||||
### `tf.nn.ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, top_paths=1, merge_repeated=True)` {#ctc_beam_search_decoder}
|
||||
|
||||
Performs beam search decoding on the logits given in input.
|
||||
|
||||
If `merge_repeated` is `True`, merge repeated classes in output.
|
||||
This means that if consecutive entries in a beam are the same,
|
||||
only the first of these is emitted. That is, when the top path
|
||||
is `A B B B B`, `A B` is returned if `merge_repeated = True`
|
||||
but `A B B B B` is returned if `merge_repeated = False`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`inputs`</b>: 3-D `float` `Tensor`, size
|
||||
`[max_time x batch_size x num_classes]`. The logits.
|
||||
* <b>`sequence_length`</b>: 1-D `int32` vector containing sequence lengths,
|
||||
having size `[batch_size]`.
|
||||
* <b>`beam_width`</b>: An int scalar >= 0 (beam search beam width).
|
||||
* <b>`top_paths`</b>: An int scalar >= 0, <= beam_width (controls output size).
|
||||
* <b>`merge_repeated`</b>: Boolean. Default: True.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tuple `(decoded, log_probabilities)` where
|
||||
|
||||
* <b>`decoded`</b>: A list of length top_paths, where `decoded[j]`
|
||||
is a `SparseTensor` containing the decoded outputs:
|
||||
`decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
|
||||
The rows store: [batch, time].
|
||||
`decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
|
||||
The vector stores the decoded classes for beam j.
|
||||
`decoded[j].shape`: Shape vector, size `(2)`.
|
||||
The shape values are: `[batch_size, max_decoded_length[j]]`.
|
||||
* <b>`log_probability`</b>: A `float` matrix `(batch_size x top_paths)` containing
|
||||
sequence log-probabilities.
|
||||
|
@ -0,0 +1,38 @@
|
||||
### `tf.nn.ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True)` {#ctc_greedy_decoder}
|
||||
|
||||
Performs greedy decoding on the logits given in input (best path).
|
||||
|
||||
Note: Regardless of the value of merge_repeated, if the maximum index of a
|
||||
given time and batch corresponds to the blank index `(num_classes - 1)`, no
|
||||
new element is emitted.
|
||||
|
||||
If `merge_repeated` is `True`, merge repeated classes in output.
|
||||
This means that if consecutive logits' maximum indices are the same,
|
||||
only the first of these is emitted. Labeling the blank '*', the sequence
|
||||
`A B B * B B` becomes `A B` if `merge_repeated = True` and `A B B B B`
|
||||
if `merge_repeated = False`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`inputs`</b>: 3-D `float` `Tensor` sized
|
||||
`[max_time x batch_size x num_classes]`. The logits.
|
||||
* <b>`sequence_length`</b>: 1-D `int32` vector containing sequence lengths,
|
||||
having size `[batch_size]`.
|
||||
* <b>`merge_repeated`</b>: Boolean. Default: True.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tuple `(decoded, log_probabilities)` where
|
||||
|
||||
* <b>`decoded`</b>: A single-element list. `decoded[0]`
|
||||
is a `SparseTensor` containing the decoded outputs s.t.:
|
||||
`decoded.indices`: Indices matrix `(total_decoded_outputs x 2)`.
|
||||
The rows store: `[batch, time]`.
|
||||
`decoded.values`: Values vector, size `(total_decoded_outputs)`.
|
||||
The vector stores the decoded classes.
|
||||
`decoded.shape`: Shape vector, size `(2)`.
|
||||
The shape values are: `[batch_size, max_decoded_length]`
|
||||
* <b>`log_probability`</b>: A `float` matrix `(batch_size x 1)` containing sequence
|
||||
log-probabilities.
|
||||
|
@ -428,6 +428,9 @@
|
||||
* [`conv2d`](../../api_docs/python/nn.md#conv2d)
|
||||
* [`conv2d_transpose`](../../api_docs/python/nn.md#conv2d_transpose)
|
||||
* [`conv3d`](../../api_docs/python/nn.md#conv3d)
|
||||
* [`ctc_beam_search_decoder`](../../api_docs/python/nn.md#ctc_beam_search_decoder)
|
||||
* [`ctc_greedy_decoder`](../../api_docs/python/nn.md#ctc_greedy_decoder)
|
||||
* [`ctc_loss`](../../api_docs/python/nn.md#ctc_loss)
|
||||
* [`depthwise_conv2d`](../../api_docs/python/nn.md#depthwise_conv2d)
|
||||
* [`depthwise_conv2d_native`](../../api_docs/python/nn.md#depthwise_conv2d_native)
|
||||
* [`dilation2d`](../../api_docs/python/nn.md#dilation2d)
|
||||
|
@ -1692,6 +1692,163 @@ length(s) of the sequence(s) or completely unrolled if length(s) is not given.
|
||||
|
||||
|
||||
|
||||
## Connectionist Temporal Classification (CTC)
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.nn.ctc_loss(inputs, labels, sequence_length, preprocess_collapse_repeated=False, ctc_merge_repeated=True)` {#ctc_loss}
|
||||
|
||||
Computes the CTC (Connectionist Temporal Classification) Loss.
|
||||
|
||||
See the article:
|
||||
|
||||
A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
|
||||
Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
|
||||
with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, pp. 369-376.
|
||||
|
||||
Input requirements:
|
||||
|
||||
```
|
||||
sequence_length(b) <= time for all b
|
||||
|
||||
max(labels.indices(labels.indices[:, 1] == b, 2))
|
||||
<= sequence_length(b) for all b.
|
||||
```
|
||||
|
||||
Regarding the arguments `preprocess_collapse_repeated` and
|
||||
`ctc_merge_repeated`:
|
||||
|
||||
If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
|
||||
repeated non-blank labels will not be merged and are interpreted
|
||||
as individual labels. This is a simplified (non-standard) version of CTC.
|
||||
|
||||
Here is a table of the (roughly) expected first order behavior:
|
||||
|
||||
* `preprocess_collapse_repeated=False, ctc_merge_repeated=True`
|
||||
|
||||
Classical CTC behavior: Outputs true repeated classes with nulls in
|
||||
between, and can also output repeated classes with no nulls in
|
||||
between that need to be collapsed by the decoder.
|
||||
|
||||
* `preprocess_collapse_repeated=True, ctc_merge_repeated=False`
|
||||
|
||||
Never learns to output repeated classes, as they are collapsed in the input labels before training.
|
||||
|
||||
* `preprocess_collapse_repeated=False, ctc_merge_repeated=False`
|
||||
|
||||
Outputs repeated classes with nulls in between, but generally does not
|
||||
require the decoder to collapse/merge repeated classes.
|
||||
|
||||
* `preprocess_collapse_repeated=True, ctc_merge_repeated=True`
|
||||
|
||||
Untested. Very likely will not learn to output repeated classes.
|
||||
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`inputs`</b>: 3-D `float` `Tensor` sized
|
||||
`[max_time x batch_size x num_classes]`. The logits.
|
||||
* <b>`labels`</b>: An `int32` `SparseTensor`.
|
||||
`labels.indices[i, :] == [b, t]` means `labels.values[i]` stores
|
||||
the id for (batch b, time t). See `core/ops/ctc_ops.cc` for more details.
|
||||
* <b>`sequence_length`</b>: 1-D `int32` vector, size `[batch_size]`.
|
||||
The sequence lengths.
|
||||
* <b>`preprocess_collapse_repeated`</b>: Boolean. Default: False.
|
||||
If True, repeated labels are collapsed prior to the CTC calculation.
|
||||
* <b>`ctc_merge_repeated`</b>: Boolean. Default: True.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A 1-D `float` `Tensor`, size `[batch]`, containing the negative log probabilities.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if labels is not a `SparseTensor`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.nn.ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True)` {#ctc_greedy_decoder}
|
||||
|
||||
Performs greedy decoding on the logits given in input (best path).
|
||||
|
||||
Note: Regardless of the value of merge_repeated, if the maximum index of a
|
||||
given time and batch corresponds to the blank index `(num_classes - 1)`, no
|
||||
new element is emitted.
|
||||
|
||||
If `merge_repeated` is `True`, merge repeated classes in output.
|
||||
This means that if consecutive logits' maximum indices are the same,
|
||||
only the first of these is emitted. Labeling the blank '*', the sequence
|
||||
`A B B * B B` becomes `A B` if `merge_repeated = True` and `A B B B B`
|
||||
if `merge_repeated = False`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`inputs`</b>: 3-D `float` `Tensor` sized
|
||||
`[max_time x batch_size x num_classes]`. The logits.
|
||||
* <b>`sequence_length`</b>: 1-D `int32` vector containing sequence lengths,
|
||||
having size `[batch_size]`.
|
||||
* <b>`merge_repeated`</b>: Boolean. Default: True.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tuple `(decoded, log_probabilities)` where
|
||||
|
||||
* <b>`decoded`</b>: A single-element list. `decoded[0]`
|
||||
is a `SparseTensor` containing the decoded outputs s.t.:
|
||||
`decoded.indices`: Indices matrix `(total_decoded_outputs x 2)`.
|
||||
The rows store: `[batch, time]`.
|
||||
`decoded.values`: Values vector, size `(total_decoded_outputs)`.
|
||||
The vector stores the decoded classes.
|
||||
`decoded.shape`: Shape vector, size `(2)`.
|
||||
The shape values are: `[batch_size, max_decoded_length]`
|
||||
* <b>`log_probability`</b>: A `float` matrix `(batch_size x 1)` containing sequence
|
||||
log-probabilities.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.nn.ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, top_paths=1, merge_repeated=True)` {#ctc_beam_search_decoder}
|
||||
|
||||
Performs beam search decoding on the logits given in input.
|
||||
|
||||
If `merge_repeated` is `True`, merge repeated classes in output.
|
||||
This means that if consecutive entries in a beam are the same,
|
||||
only the first of these is emitted. That is, when the top path
|
||||
is `A B B B B`, `A B` is returned if `merge_repeated = True`
|
||||
but `A B B B B` is returned if `merge_repeated = False`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`inputs`</b>: 3-D `float` `Tensor`, size
|
||||
`[max_time x batch_size x num_classes]`. The logits.
|
||||
* <b>`sequence_length`</b>: 1-D `int32` vector containing sequence lengths,
|
||||
having size `[batch_size]`.
|
||||
* <b>`beam_width`</b>: An int scalar >= 0 (beam search beam width).
|
||||
* <b>`top_paths`</b>: An int scalar >= 0, <= beam_width (controls output size).
|
||||
* <b>`merge_repeated`</b>: Boolean. Default: True.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tuple `(decoded, log_probabilities)` where
|
||||
|
||||
* <b>`decoded`</b>: A list of length top_paths, where `decoded[j]`
|
||||
is a `SparseTensor` containing the decoded outputs:
|
||||
`decoded[j].indices`: Indices matrix `(total_decoded_outputs[j] x 2)`
|
||||
The rows store: [batch, time].
|
||||
`decoded[j].values`: Values vector, size `(total_decoded_outputs[j])`.
|
||||
The vector stores the decoded classes for beam j.
|
||||
`decoded[j].shape`: Shape vector, size `(2)`.
|
||||
The shape values are: `[batch_size, max_decoded_length[j]]`.
|
||||
* <b>`log_probability`</b>: A `float` matrix `(batch_size x top_paths)` containing
|
||||
sequence log-probabilities.
|
||||
|
||||
|
||||
|
||||
## Evaluation
|
||||
|
||||
The evaluation ops are useful for measuring the performance of a network.
|
||||
|
@ -776,6 +776,7 @@ py_library(
|
||||
"ops/clip_ops.py",
|
||||
"ops/control_flow_grad.py",
|
||||
"ops/control_flow_ops.py",
|
||||
"ops/ctc_ops.py",
|
||||
"ops/data_flow_grad.py",
|
||||
"ops/data_flow_ops.py",
|
||||
"ops/embedding_ops.py",
|
||||
|
@ -140,9 +140,7 @@ def all_libraries(module_to_name, members, documented):
|
||||
"softsign_grad", "xw_plus_b", "relu_layer",
|
||||
"lrn", "batch_norm_with_global_normalization",
|
||||
"batch_norm_with_global_normalization_grad",
|
||||
"all_candidate_sampler", "rnn",
|
||||
"state_saving_rnn", "bidirectional_rnn",
|
||||
"dynamic_rnn", "seq2seq", "rnn_cell"],
|
||||
"all_candidate_sampler", "seq2seq"],
|
||||
prefix=PREFIX_TEXT),
|
||||
library("rnn_cell", "Neural Network RNN Cells", tf.nn.rnn_cell),
|
||||
library("client", "Running Graphs", client_lib),
|
||||
|
@ -29,6 +29,8 @@ py_tests(
|
||||
"candidate_sampler_ops_test.py",
|
||||
"cholesky_op_test.py",
|
||||
"clip_ops_test.py",
|
||||
"ctc_decoder_ops_test.py",
|
||||
"ctc_loss_op_test.py",
|
||||
"decode_csv_op_test.py",
|
||||
"decode_png_op_test.py",
|
||||
"decode_raw_op_test.py",
|
||||
|
@ -143,7 +143,7 @@ class CTCGreedyDecoderTest(tf.test.TestCase):
|
||||
]
|
||||
|
||||
self._testCTCDecoder(
|
||||
tf.contrib.ctc.ctc_greedy_decoder,
|
||||
tf.nn.ctc_greedy_decoder,
|
||||
inputs, seq_lens, log_prob_truth, decode_truth)
|
||||
|
||||
def testCTCDecoderBeamSearch(self):
|
||||
@ -191,7 +191,7 @@ class CTCGreedyDecoderTest(tf.test.TestCase):
|
||||
]
|
||||
|
||||
self._testCTCDecoder(
|
||||
tf.contrib.ctc.ctc_beam_search_decoder,
|
||||
tf.nn.ctc_beam_search_decoder,
|
||||
inputs, seq_lens, log_prob_truth,
|
||||
decode_truth,
|
||||
beam_width=2,
|
@ -54,9 +54,9 @@ class CTCLossTest(tf.test.TestCase):
|
||||
inputs_t = tf.constant(inputs)
|
||||
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
loss = tf.contrib.ctc.ctc_loss(inputs=inputs_t,
|
||||
labels=labels,
|
||||
sequence_length=seq_lens)
|
||||
loss = tf.nn.ctc_loss(inputs=inputs_t,
|
||||
labels=labels,
|
||||
sequence_length=seq_lens)
|
||||
grad = tf.gradients(loss, [inputs_t])[0]
|
||||
|
||||
self.assertShapeEqual(loss_truth, loss)
|
@ -13,42 +13,74 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# pylint: disable=unused-import
|
||||
"""CTC (Connectionist Temporal Classification) Operations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
|
||||
from tensorflow.python.ops import gen_ctc_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.nn_grad import _BroadcastMul
|
||||
|
||||
|
||||
# NOTE(ebrevdo): We redefine CTCLoss from gen_ctc_ops to only return
|
||||
# the first output. The second output is only used for the gradient.
|
||||
# pylint: disable=protected-access, invalid-name
|
||||
def ctc_loss(inputs, labels, sequence_length,
|
||||
preprocess_collapse_repeated=False, ctc_merge_repeated=True):
|
||||
"""Computes the CTC (Connectionist Temporal Classification) Loss.
|
||||
|
||||
Requires:
|
||||
```sequence_length(b) <= time for all b
|
||||
This op implements the CTC loss as presented in the article:
|
||||
|
||||
max(labels.indices(labels.indices[:, 1] == b, 2))
|
||||
<= sequence_length(b) for all b.```
|
||||
A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
|
||||
Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
|
||||
with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, pp. 369-376.
|
||||
|
||||
If ctc_merge_repeated is set False, then *during* CTC calculation
|
||||
http://www.cs.toronto.edu/~graves/icml_2006.pdf
|
||||
|
||||
Input requirements:
|
||||
|
||||
```
|
||||
sequence_length(b) <= time for all b
|
||||
|
||||
max(labels.indices(labels.indices[:, 1] == b, 2))
|
||||
<= sequence_length(b) for all b.
|
||||
```
|
||||
|
||||
Regarding the arguments `preprocess_collapse_repeated` and
|
||||
`ctc_merge_repeated`:
|
||||
|
||||
If `preprocess_collapse_repeated` is True, then a preprocessing step runs
|
||||
before loss calculation, wherein repeated labels passed to the loss
|
||||
are merged into single labels. This is useful if the training labels come
|
||||
from, e.g., forced alignments and therefore have unnecessary repetitions.
|
||||
|
||||
If `ctc_merge_repeated` is set False, then deep within the CTC calculation,
|
||||
repeated non-blank labels will not be merged and are interpreted
|
||||
as individual labels. This is a simplified version of CTC.
|
||||
as individual labels. This is a simplified (non-standard) version of CTC.
|
||||
|
||||
Here is a table of the (roughly) expected first order behavior:
|
||||
|
||||
* `preprocess_collapse_repeated=False`, `ctc_merge_repeated=True`
|
||||
|
||||
Classical CTC behavior: Outputs true repeated classes with blanks in
|
||||
between, and can also output repeated classes with no blanks in
|
||||
between that need to be collapsed by the decoder.
|
||||
|
||||
* `preprocess_collapse_repeated=True`, `ctc_merge_repeated=False`
|
||||
|
||||
Never learns to output repeated classes, as they are collapsed
|
||||
in the input labels before training.
|
||||
|
||||
* `preprocess_collapse_repeated=False`, `ctc_merge_repeated=False`
|
||||
|
||||
Outputs repeated classes with blanks in between, but generally does not
|
||||
require the decoder to collapse/merge repeated classes.
|
||||
|
||||
* `preprocess_collapse_repeated=True`, `ctc_merge_repeated=True`
|
||||
|
||||
Untested. Very likely will not learn to output repeated classes.
|
||||
|
||||
Args:
|
||||
inputs: 3-D `float` `Tensor` sized
|
||||
@ -62,11 +94,9 @@ def ctc_loss(inputs, labels, sequence_length,
|
||||
If True, repeated labels are collapsed prior to the CTC calculation.
|
||||
ctc_merge_repeated: Boolean. Default: True.
|
||||
|
||||
|
||||
Returns:
|
||||
A 1-D `float` `Tensor`, size `[batch]`, containing the negative log probabilities.
|
||||
|
||||
|
||||
Raises:
|
||||
TypeError: if labels is not a `SparseTensor`.
|
||||
"""
|
||||
@ -129,12 +159,13 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
||||
given time and batch corresponds to the blank index `(num_classes - 1)`, no
|
||||
new element is emitted.
|
||||
|
||||
If merge_repeated is `True`, merge repeated classes in output.
|
||||
If `merge_repeated` is `True`, merge repeated classes in output.
|
||||
This means that if consecutive logits' maximum indices are the same,
|
||||
only the first of these is emitted. Labeling the blank '*', the sequence
|
||||
"A B B * B B" becomes "A B" if `merge_repeated = True` and "A B B B B"
|
||||
if `merge_repeated = False`.
|
||||
only the first of these is emitted. The sequence `A B B * B * B` (where '*'
|
||||
is the blank label) becomes
|
||||
|
||||
* `A B` if `merge_repeated=True`.
|
||||
* `A B B B B B` if `merge_repeated=False`.
|
||||
|
||||
Args:
|
||||
inputs: 3-D `float` `Tensor` sized
|
||||
@ -143,7 +174,6 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
||||
having size `[batch_size]`.
|
||||
merge_repeated: Boolean. Default: True.
|
||||
|
||||
|
||||
Returns:
|
||||
A tuple `(decoded, log_probabilities)` where
|
||||
decoded: A single-element list. `decoded[0]`
|
||||
@ -184,12 +214,17 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
|
||||
top_paths=1, merge_repeated=True):
|
||||
"""Performs beam search decoding on the logits given in input.
|
||||
|
||||
If merge_repeated is `True`, merge repeated classes in output.
|
||||
**Note** The `ctc_greedy_decoder` is a special case of the
|
||||
`ctc_beam_search_decoder` with `top_paths=1` (but that decoder is faster
|
||||
for this special case).
|
||||
|
||||
If `merge_repeated` is `True`, merge repeated classes in the output beams.
|
||||
This means that if consecutive entries in a beam are the same,
|
||||
only the first of these is emitted. That is, when the top path
|
||||
is "A B B B B", "A B" is returned if `merge_repeated = True`
|
||||
but "A B B B B" is returned if `merge_repeated = False`.
|
||||
is `A B B B B`, the return value is:
|
||||
|
||||
* `A B` if `merge_repeated = True`.
|
||||
* `A B B B B` if `merge_repeated = False`.
|
||||
|
||||
Args:
|
||||
inputs: 3-D `float` `Tensor`, size
|
||||
@ -200,7 +235,6 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
|
||||
top_paths: An int scalar >= 0, <= beam_width (controls output size).
|
||||
merge_repeated: Boolean. Default: True.
|
||||
|
||||
|
||||
Returns:
|
||||
A tuple `(decoded, log_probabilities)` where
|
||||
decoded: A list of length top_paths, where `decoded[j]`
|
@ -220,6 +220,12 @@ Neural Networks. Most accept an `RNNCell`-subclassed object
|
||||
@@state_saving_rnn
|
||||
@@bidirectional_rnn
|
||||
|
||||
## Connectionist Temporal Classification (CTC)
|
||||
|
||||
@@ctc_loss
|
||||
@@ctc_greedy_decoder
|
||||
@@ctc_beam_search_decoder
|
||||
|
||||
## Evaluation
|
||||
|
||||
The evaluation ops are useful for measuring the performance of a network.
|
||||
@ -294,6 +300,7 @@ from tensorflow.python.util.all_util import make_all
|
||||
# Bring more nn-associated functionality into this package.
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.ctc_ops import *
|
||||
from tensorflow.python.ops.nn_ops import *
|
||||
from tensorflow.python.ops.candidate_sampling_ops import *
|
||||
from tensorflow.python.ops.embedding_ops import *
|
||||
@ -1186,14 +1193,10 @@ __all__.extend([
|
||||
"all_candidate_sampler",
|
||||
"batch_norm_with_global_normalization",
|
||||
"batch_normalization",
|
||||
"bidirectional_rnn",
|
||||
"conv2d_backprop_filter",
|
||||
"conv2d_backprop_input",
|
||||
"depthwise_conv2d_native",
|
||||
"dynamic_rnn",
|
||||
"lrn",
|
||||
"relu_layer",
|
||||
"rnn",
|
||||
"state_saving_rnn",
|
||||
"xw_plus_b",
|
||||
])
|
||||
|
Loading…
Reference in New Issue
Block a user