# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TfLite LSTMCell wrapper.
|
|
|
|
TODO(renjieliu): Find a better home for this one.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.lite.python.op_hint import OpHint
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import control_flow_util
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import rnn_cell_impl
|
|
from tensorflow.python.ops import variable_scope as vs
|
|
from tensorflow.python.ops.rnn import _best_effort_input_batch_size
|
|
from tensorflow.python.ops.rnn import _dynamic_rnn_loop
|
|
from tensorflow.python.ops.rnn import _should_cache
|
|
from tensorflow.python.ops.rnn import _transpose_batch_time
|
|
from tensorflow.python.util import nest
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
@tf_export(v1=["lite.experimental.nn.dynamic_rnn"])
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=True,
                scope=None):
"""Creates a recurrent neural network specified by RNNCell `cell`.
|
|
|
|
Performs fully dynamic unrolling of `inputs`.
|
|
|
|
Example:
|
|
|
|
```python
|
|
# create a BasicRNNCell
|
|
rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)
|
|
|
|
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
|
|
|
|
# defining initial state
|
|
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
|
|
|
|
# 'state' is a tensor of shape [batch_size, cell_state_size]
|
|
outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
|
|
initial_state=initial_state,
|
|
dtype=tf.float32)
|
|
```
|
|
|
|
```python
|
|
# create 2 LSTMCells
|
|
rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
|
|
|
|
# create a RNN cell composed sequentially of a number of RNNCells
|
|
multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)
|
|
|
|
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
|
|
# 'state' is a N-tuple where N is the number of LSTMCells containing a
|
|
# tf.nn.rnn_cell.LSTMStateTuple for each cell
|
|
outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
|
|
inputs=data,
|
|
dtype=tf.float32)
|
|
```
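
  This TfLite wrapper currently requires time-major inputs (see `time_major`
  below). A minimal usage sketch, assuming control flow v2 is enabled, the
  companion `TFLiteLSTMCell` is available under the same export path, and
  hypothetical shape names `time_steps`, `batch_size` and `num_features`:

  ```python
  # OpHint-based dynamic_rnn only works with control flow v2.
  tf.compat.v1.enable_control_flow_v2()

  # Inputs must be time-major: [max_time, batch_size, depth].
  inputs = tf.compat.v1.placeholder(
      tf.float32, shape=[time_steps, batch_size, num_features])
  lstm_cell = tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell(num_units=64)
  outputs, state = tf.compat.v1.lite.experimental.nn.dynamic_rnn(
      lstm_cell, inputs, dtype=tf.float32, time_major=True)
  ```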

  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False`, this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True` (the default for this wrapper), this must be a
        `Tensor` of shape: `[max_time, batch_size, ...]`, or a nested tuple of
        such elements. This may also be a (possibly nested) tuple of Tensors
        satisfying this property. The first two dimensions must match across
        all the inputs, but otherwise the ranks and other shape components may
        differ. In this case, input to `cell` at each time-step will replicate
        the structure of these tuples, except for the time dimension (from
        which the time is taken). The input to `cell` at each time step will be
        a `Tensor` or (possibly nested) tuple of Tensors each with dimensions
        `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length. So it's more for performance than correctness.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a
      heterogeneous dtype.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency and
      can be run in parallel, will be. This parameter trades off time for
      space. Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation. Most TensorFlow data is
      batch-major, but unlike the standard `tf.nn.dynamic_rnn` this wrapper
      defaults to `time_major = True` and currently supports only the
      time-major case.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False, this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True (the default for this wrapper), this will be a
        `Tensor` shaped: `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state. If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`. If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`,
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    RuntimeError: If not using control flow v2.
  """

  # Currently only support time_major == True case.
  assert time_major

  # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
  # TfLiteRNNCells.
  rnn_cell_impl.assert_like_rnncell("cell", cell)

  if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
    raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

  parent_first_child_input = [{
      "parent_ophint_input_index": 0,
      "first_child_ophint_input_index": 0
  }]
  parent_last_child_output = [{
      "parent_output_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
  internal_children_input_output = [{
      "child_input_index": 0,
      # For LstmCell, the index is 2.
      # For RnnCell, the index is 1.
      # So we use -1 meaning it's the last one.
      "child_output_index": -1
  }]
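  # Together these mappings tell OpHint how to stitch the per-time-step child
  # hints into the single parent "TfLiteDynamicRnn" hint: the parent's input 0
  # feeds the first child's input 0, each child's last output feeds the next
  # child's input 0, and the last child's last output becomes the parent's
  # output 0.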
  inputs_outputs_mappings = {
      "parent_first_child_input": parent_first_child_input,
      "parent_last_child_output": parent_last_child_output,
      "internal_children_input_output": internal_children_input_output
  }
  tflite_wrapper = OpHint(
      "TfLiteDynamicRnn",
      level=2,
      children_inputs_mappings=inputs_outputs_mappings)
  with vs.variable_scope(scope or "rnn") as varscope:
    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)

    # If time_major == False, inputs are batch-major: shaped
    #   [batch, time, depth]
    # For internal calculations, we transpose to [time, batch, depth].
    flat_input = nest.flatten(inputs)

    if not time_major:
      # (batch, time, depth) => (time, batch, depth)
      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      if sequence_length.shape.rank not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size, "
            "but saw shape: %s" % sequence_length.shape)
      sequence_length = array_ops.identity(  # Just to find it in the graph.
          sequence_length,
          name="sequence_length")

    batch_size = _best_effort_input_batch_size(flat_input)

    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If there is no initial_state, you must give a dtype.")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    def _assert_has_shape(x, shape):
      x_shape = array_ops.shape(x)
      packed_shape = array_ops.stack(shape)
      return control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
              "Expected shape for Tensor %s is " % x.name, packed_shape,
              " but saw shape: ", x_shape
          ])

    if not context.executing_eagerly() and sequence_length is not None:
      # Perform some shape validation
      with ops.control_dependencies(
          [_assert_has_shape(sequence_length, [batch_size])]):
        sequence_length = array_ops.identity(
            sequence_length, name="CheckSeqLen")

    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)

    outputs, final_state = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)

    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    # If we are performing batch-major calculations, transpose output back
    # to shape [batch, time, depth].
    if not time_major:
      # (time, batch, depth) => (batch, time, depth)
      outputs = nest.map_structure(_transpose_batch_time, outputs)
    outputs = tflite_wrapper.add_output(outputs, name="outputs")

    return outputs, final_state


def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
"""Creates a dynamic version of bidirectional recurrent neural network.
|
|
|
|
Takes input and builds independent forward and backward RNNs. The input_size
|
|
of forward and backward cell must match. The initial state for both directions
|
|
is zero by default (but can be set optionally) and no intermediate states are
|
|
ever returned -- the network is fully unrolled for the given (passed in)
|
|
length(s) of the sequence(s) or completely unrolled if length(s) is not
|
|
given.
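
  Because the wrapped `dynamic_rnn` currently supports only the time-major
  case, `time_major=True` should be passed. A minimal usage sketch, assuming
  TfLite LSTM cells and hypothetical shape names `time_steps`, `batch_size`
  and `num_features`:

  ```python
  fw_cell = tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell(num_units=64)
  bw_cell = tf.compat.v1.lite.experimental.nn.TFLiteLSTMCell(num_units=64)
  # Inputs must be time-major: [max_time, batch_size, depth].
  inputs = tf.compat.v1.placeholder(
      tf.float32, shape=[time_steps, batch_size, num_features])
  (output_fw, output_bw), _ = bidirectional_dynamic_rnn(
      fw_cell, bw_cell, inputs, dtype=tf.float32, time_major=True)
  # Concatenate along the depth axis if a single output tensor is preferred.
  outputs = tf.concat([output_fw, output_bw], axis=2)
  ```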

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: The RNN inputs.
      If time_major == False (default), this must be a tensor of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If time_major == True, this must be a tensor of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences in the batch. If
      not provided, all batch entries are assumed to be full sequences; and time
      reversal is applied from time `0` to `max_time` for each sequence.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial states and expected output.
      Required if initial_states are not provided or RNN states have a
      heterogeneous dtype.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency and
      can be run in parallel, will be. This parameter trades off time for
      space. Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU. This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation. However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form. Note that the underlying TfLite `dynamic_rnn`
      wrapper currently supports only `time_major = True`.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn".

  Returns:
    A tuple (outputs, output_states) where:
      outputs: A tuple (output_fw, output_bw) containing the forward and
        the backward rnn output `Tensor`.
        If time_major == False (default),
          output_fw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[batch_size, max_time, cell_bw.output_size]`.
        If time_major == True,
          output_fw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_fw.output_size]`
          and output_bw will be a `Tensor` shaped:
          `[max_time, batch_size, cell_bw.output_size]`.
        It returns a tuple instead of a single concatenated `Tensor`, unlike
        in the `bidirectional_rnn`. If the concatenated one is preferred,
        the forward and backward outputs can be concatenated as
        `tf.concat(outputs, 2)`.
      output_states: A tuple (output_state_fw, output_state_bw) containing
        the forward and the backward final states of bidirectional rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_fw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=fw_scope)

    # Backward direction
    if not time_major:
      time_axis = 1
      batch_axis = 0
    else:
      time_axis = 0
      batch_axis = 1

    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_,
            seq_lengths=seq_lengths,
            seq_axis=seq_axis,
            batch_axis=batch_axis)
      else:
        return array_ops.reverse(input_, axis=[seq_axis])

    with vs.variable_scope("bw") as bw_scope:

      def _map_reverse(inp):
        return _reverse(
            inp,
            seq_lengths=sequence_length,
            seq_axis=time_axis,
            batch_axis=batch_axis)

      inputs_reverse = nest.map_structure(_map_reverse, inputs)
      tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw,
          inputs=inputs_reverse,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          parallel_iterations=parallel_iterations,
          swap_memory=swap_memory,
          time_major=time_major,
          scope=bw_scope)

  output_bw = _reverse(
      tmp,
      seq_lengths=sequence_length,
      seq_axis=time_axis,
      batch_axis=batch_axis)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)