STT-tensorflow/tensorflow/contrib/crf/python/ops/crf.py
Yifei Feng 22f3a97b8b Merge changes from github.
PiperOrigin-RevId: 194031845
2018-04-23 21:21:38 -07:00

520 lines
20 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing a linear-chain CRF.
The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most
likely sequence at test time. There are two ways to decode: one uses
crf_decode to decode inside TensorFlow, and the other uses viterbi_decode in
NumPy.
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
unary_scores, gold_tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# Decoding in Tensorflow.
viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
unary_scores, transition_params, sequence_lengths)
tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
[viterbi_sequence, viterbi_score, train_op])
# Decoding in Numpy.
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
[unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
                                                 tf_sequence_lengths):
  # Remove padding.
  tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
  # Compute the highest score and its tag sequence.
  tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
      tf_unary_scores_, tf_transition_params)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs
# Public API of this module: the names exported by a wildcard import.
__all__ = [
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
"viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
"CrfDecodeBackwardRnnCell"
]
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
  """Computes the unnormalized score for a tag sequence.

  The score of a sequence is the sum of the unary potentials of its tags plus
  the sum of the transition potentials between consecutive tags, restricted
  to the first `sequence_lengths` steps of each example. Examples whose
  sequence length is zero (or negative) score 0.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.

  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """
  # If max_seq_len is 1, we skip the score calculation and simply gather the
  # unary potentials of the single tag.
  def _single_seq_fn():
    batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = array_ops.reshape(
        math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    sequence_scores = array_ops.gather_nd(
        array_ops.squeeze(inputs, [1]),
        array_ops.concat([example_inds, tag_indices], axis=1))
    # The multi-step branch masks out every step at or beyond the sequence
    # length, so an empty sequence scores 0 there. Apply the same rule here;
    # otherwise an empty example would be credited with its first unary score.
    sequence_scores = array_ops.where(
        math_ops.less_equal(sequence_lengths, 0),
        array_ops.zeros_like(sequence_scores),
        sequence_scores)
    return sequence_scores

  def _multi_seq_fn():
    # Compute the scores of the given tag sequence.
    unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
    binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                     transition_params)
    sequence_scores = unary_scores + binary_scores
    return sequence_scores

  # Prefer the static time dimension when it is known so that smart_cond can
  # pick the branch at graph-construction time; otherwise fall back to the
  # dynamic shape.
  return utils.smart_cond(
      pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
                          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)
def crf_log_norm(inputs, sequence_lengths, transition_params):
  """Computes the normalization for a CRF.

  Examples whose sequence length is zero (or negative) get a normalizer of 0,
  matching the score of 0 that crf_sequence_score's multi-step path assigns
  to an empty sequence.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.

  Returns:
    log_norm: A [batch_size] vector of normalizers for a CRF.
  """
  # Split up the first and rest of the inputs in preparation for the forward
  # algorithm.
  first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
  first_input = array_ops.squeeze(first_input, [1])

  def _mask_empty_sequences(log_norm):
    """Replaces the normalizer of sequences with length <= 0 by zero."""
    return array_ops.where(
        math_ops.less_equal(sequence_lengths, 0),
        array_ops.zeros_like(log_norm),
        log_norm)

  # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
  # the "initial state" (the unary potentials).
  def _single_seq_fn():
    return _mask_empty_sequences(math_ops.reduce_logsumexp(first_input, [1]))

  def _multi_seq_fn():
    """Forward computation of alpha values."""
    rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

    # Compute the alpha values in the forward algorithm in order to get the
    # partition function.
    forward_cell = CrfForwardRnnCell(transition_params)
    # The first step is consumed as the initial state, so the RNN runs over
    # one step fewer. Clamp at zero so that an empty sequence does not yield
    # a negative length; the zero tensor keeps the dtype of sequence_lengths.
    sequence_lengths_less_one = math_ops.maximum(
        array_ops.zeros_like(sequence_lengths), sequence_lengths - 1)
    _, alphas = rnn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths_less_one,
        initial_state=first_input,
        dtype=dtypes.float32)
    log_norm = math_ops.reduce_logsumexp(alphas, [1])
    return _mask_empty_sequences(log_norm)

  max_seq_len = array_ops.shape(inputs)[1]
  return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
                               true_fn=_single_seq_fn,
                               false_fn=_multi_seq_fn)
def crf_log_likelihood(inputs,
                       tag_indices,
                       sequence_lengths,
                       transition_params=None):
  """Computes the log-likelihood of tag sequences in a CRF.

  The log-likelihood of each example is the unnormalized score of its tag
  sequence minus the log of the CRF partition function.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the log-likelihood.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix, if available.
        When omitted, a variable named "transitions" is obtained from the
        current variable scope.

  Returns:
    log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
      each example, given the sequence of tag indices.
    transition_params: A [num_tags, num_tags] transition matrix. This is either
        provided by the caller or created in this function.
  """
  # The static number of tags sizes the transition matrix when we create it.
  num_tags = inputs.get_shape()[2].value

  if transition_params is None:
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])

  unnormalized_scores = crf_sequence_score(
      inputs, tag_indices, sequence_lengths, transition_params)
  partition = crf_log_norm(inputs, sequence_lengths, transition_params)

  # Subtracting the log partition function normalizes the scores.
  return unnormalized_scores - partition, transition_params
def crf_unary_score(tag_indices, sequence_lengths, inputs):
  """Computes the unary scores of tag sequences.

  For every example, sums the unary potential of the given tag at each time
  step, ignoring steps at or beyond the example's sequence length.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.

  Returns:
    unary_scores: A [batch_size] vector of unary scores, with the same dtype
      as `inputs`.
  """
  batch_size = array_ops.shape(inputs)[0]
  max_seq_len = array_ops.shape(inputs)[1]
  num_tags = array_ops.shape(inputs)[2]

  flattened_inputs = array_ops.reshape(inputs, [-1])

  # offsets[b, t] is the position, in the flattened inputs, of the first tag
  # score of example b at time step t; adding the tag index selects the score
  # of the requested tag.
  offsets = array_ops.expand_dims(
      math_ops.range(batch_size) * max_seq_len * num_tags, 1)
  offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
  # Use int32 or int64 based on tag_indices' dtype.
  if tag_indices.dtype == dtypes.int64:
    offsets = math_ops.to_int64(offsets)
  flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])

  unary_scores = array_ops.reshape(
      array_ops.gather(flattened_inputs, flattened_tag_indices),
      [batch_size, max_seq_len])

  # Build the padding mask in the dtype of the scores so that the
  # multiplication below also works for potentials that are not float32.
  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=unary_scores.dtype)

  unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
  return unary_scores
def crf_binary_score(tag_indices, sequence_lengths, transition_params):
  """Computes the binary scores of tag sequences.

  For every example, sums the transition potential between each pair of
  consecutive tags, ignoring transitions that end at or beyond the example's
  sequence length.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.

  Returns:
    binary_scores: A [batch_size] vector of binary scores, with the same dtype
      as `transition_params`.
  """
  # Get shape information.
  num_tags = transition_params.get_shape()[0]
  num_transitions = array_ops.shape(tag_indices)[1] - 1

  # Truncate by one on each side of the sequence to get the start and end
  # indices of each transition.
  start_tag_indices = array_ops.slice(tag_indices, [0, 0],
                                      [-1, num_transitions])
  end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])

  # Encode the indices in a flattened representation.
  flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
  flattened_transition_params = array_ops.reshape(transition_params, [-1])

  # Get the binary scores based on the flattened representation.
  binary_scores = array_ops.gather(flattened_transition_params,
                                   flattened_transition_indices)

  # Build the padding mask in the dtype of the scores so that the
  # multiplication below also works for transition matrices that are not
  # float32.
  masks = array_ops.sequence_mask(sequence_lengths,
                                  maxlen=array_ops.shape(tag_indices)[1],
                                  dtype=binary_scores.dtype)
  # Transition i ends at step i + 1, so drop the mask entry of step 0.
  truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
  binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
  return binary_scores
class CrfForwardRnnCell(rnn_cell.RNNCell):
  """RNN cell that performs one step of the CRF forward (alpha) recursion.

  See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
  """

  def __init__(self, transition_params):
    """Initialize the CrfForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
          It is stored with an extra leading axis, i.e. as a
          [1, num_tags, num_tags] tensor, so that it broadcasts over the batch
          inside the cell.
    """
    self._num_tags = transition_params.get_shape()[0].value
    self._transition_params = array_ops.expand_dims(transition_params, 0)

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous alpha
          values.
      scope: Unused variable scope of this cell.

    Returns:
      new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
          containing the new alpha values.
    """
    previous_alphas = array_ops.expand_dims(state, 2)

    # Broadcasting adds self._transition_params (expanded along axis 0) to the
    # previous alphas (expanded along axis 2). In log space this addition is
    # the product of the previous alpha values and the binary potentials.
    scores = previous_alphas + self._transition_params
    summed_over_previous_tag = math_ops.reduce_logsumexp(scores, [1])
    alphas = inputs + summed_over_previous_tag

    # The alphas serve as both output and state. The output is currently
    # unused and only satisfies the RNN API; it would become useful if the
    # per-step alphas were needed, e.g. to compute marginal probabilities.
    return alphas, alphas
def viterbi_decode(score, transition_params):
  """Decode the highest scoring sequence of tags outside of TensorFlow.

  This should only be used at test time.

  Args:
    score: A [seq_len, num_tags] matrix of unary potentials. Any array-like is
        accepted; it must contain at least one time step. Non-floating-point
        scores are promoted to float64 so that fractional transition
        potentials are not truncated.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.

  Returns:
    viterbi: A [seq_len] list of integers containing the highest scoring tag
        indices.
    viterbi_score: A float containing the score for the Viterbi sequence.
  """
  score = np.asarray(score)
  # The trellis inherits the dtype of `score`; with an integer dtype the
  # accumulated path scores would be silently truncated on assignment.
  if not np.issubdtype(score.dtype, np.floating):
    score = score.astype(np.float64)

  trellis = np.zeros_like(score)
  backpointers = np.zeros_like(score, dtype=np.int32)
  trellis[0] = score[0]

  for t in range(1, score.shape[0]):
    # v[i, j] is the best score of a path that is in tag i at step t - 1 and
    # moves to tag j at step t (before adding the unary score of j).
    v = np.expand_dims(trellis[t - 1], 1) + transition_params
    trellis[t] = score[t] + np.max(v, 0)
    backpointers[t] = np.argmax(v, 0)

  # Follow the backpointers from the best final tag to the first step.
  viterbi = [np.argmax(trellis[-1])]
  for bp in reversed(backpointers[1:]):
    viterbi.append(bp[viterbi[-1]])
  viterbi.reverse()

  viterbi_score = np.max(trellis[-1])
  return viterbi, viterbi_score
class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
  """RNN cell that performs one forward (max-score) step of CRF decoding."""

  def __init__(self, transition_params):
    """Initialize the CrfDecodeForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary
        potentials. It is stored with an extra leading axis, i.e. as a
        [1, num_tags, num_tags] tensor, so that it broadcasts over the batch
        inside the cell.
    """
    self._num_tags = transition_params.get_shape()[0].value
    self._transition_params = array_ops.expand_dims(transition_params, 0)

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous step's
            score values.
      scope: Unused variable scope of this cell.

    Returns:
      backpointers: A [batch_size, num_tags] int32 matrix of backpointers.
      new_state: A [batch_size, num_tags] matrix of new score values.
    """
    # Shape comments use B for batch_size and O for num_tags.
    previous_scores = array_ops.expand_dims(state, 2)  # [B, O, 1]

    # Broadcast sum over every (previous tag, next tag) pair:
    # [B, O, 1] + [1, O, O] -> [B, O, O]
    pair_scores = previous_scores + self._transition_params

    # Best score reaching each tag, maximized over the previous tag (axis 1).
    best_incoming = math_ops.reduce_max(pair_scores, [1])  # [B, O]
    new_state = inputs + best_incoming  # [B, O]

    # The previous tag that achieves that maximum.
    best_previous_tag = math_ops.argmax(pair_scores, 1)
    backpointers = math_ops.cast(
        best_previous_tag, dtype=dtypes.int32)  # [B, O]
    return backpointers, new_state
class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
  """RNN cell that performs one backward (backpointer) step of CRF decoding."""

  def __init__(self, num_tags):
    """Initialize the CrfDecodeBackwardRnnCell.

    Args:
      num_tags: An integer. The number of tags.
    """
    self._num_tags = num_tags

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeBackwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of
            backpointer of next step (in time order).
      state: A [batch_size, 1] matrix of tag index of next step.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: A pair of [batch_size, 1] tensors containing the
        new tag indices.
    """
    # Shape comments use B for batch_size.
    next_tags = array_ops.squeeze(state, axis=[1])  # [B]
    num_examples = array_ops.shape(inputs)[0]
    example_ids = math_ops.range(num_examples)  # [B]

    # Pair every example with its next-step tag, then look up the backpointer
    # stored for that tag.
    lookup = array_ops.stack([example_ids, next_tags], axis=1)  # [B, 2]
    selected = gen_array_ops.gather_nd(inputs, lookup)  # [B]
    new_tags = array_ops.expand_dims(selected, axis=-1)  # [B, 1]
    return new_tags, new_tags
def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This is the in-graph counterpart of viterbi_decode: it operates on batched
  tensors instead of NumPy arrays.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor of
              unary potentials.
    transition_params: A [num_tags, num_tags] matrix of
              binary potentials.
    sequence_length: A [batch_size] vector of true sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
                Contains the highest scoring tag indices.
    best_score: A [batch_size] vector, containing the score of `decode_tags`.
  """
  # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
  # and the max activation.
  def _single_seq_fn():
    squeezed_potentials = array_ops.squeeze(potentials, [1])
    decode_tags = array_ops.expand_dims(
        math_ops.argmax(squeezed_potentials, axis=1), 1)
    best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
    return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score

  def _multi_seq_fn():
    """Decoding of highest scoring sequence."""

    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
    num_tags = potentials.get_shape()[2].value

    # Computes forward decoding. Get last score and backpointers.
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
    initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
    inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
    # Sequence length is not allowed to be less than zero. The lower bound is
    # a zero tensor with the dtype of `sequence_length`: a literal 0 passed as
    # the first argument would be converted to int32 and then clash with an
    # int64 `sequence_length`.
    sequence_length_less_one = math_ops.maximum(
        array_ops.zeros_like(sequence_length), sequence_length - 1)
    backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
        crf_fwd_cell,
        inputs=inputs,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    # Reverse the valid part of each example so that the backward pass can
    # walk from the last step to the first.
    backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
        backpointers, sequence_length_less_one, seq_dim=1)

    # Computes backward decoding. Extract tag indices from backpointers.
    crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
    initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),  # [B]
                                  dtype=dtypes.int32)
    initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
    decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, 1]
        crf_bwd_cell,
        inputs=backpointers,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
    # Prepend the best final tag, then undo the reversal to restore time order.
    decode_tags = array_ops.concat([initial_state, decode_tags],   # [B, T]
                                   axis=1)
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
        decode_tags, sequence_length, seq_dim=1)

    best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
    return decode_tags, best_score

  # Prefer the static time dimension when it is known so that smart_cond can
  # pick the branch at graph-construction time; otherwise fall back to the
  # dynamic shape.
  return utils.smart_cond(
      pred=math_ops.equal(potentials.shape[1].value or
                          array_ops.shape(potentials)[1], 1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)