Update RNN conversion tflite g3doc

This uses the content from the blog post/dogfood announcement email

PiperOrigin-RevId: 317248288
Change-Id: I210c64bd54c70aa5b68742d59d6d36fa154e856c
Ashwin Murthy 2020-06-18 22:10:41 -07:00 committed by TensorFlower Gardener
parent cfbdd27fe3
commit b7caba2c42


# TensorFlow RNN conversion to TensorFlow Lite
## Overview
TensorFlow Lite supports converting TensorFlow RNN models to TensorFlow Lite's
fused LSTM operators. Fused operators exist to maximize the performance of
their underlying kernel implementations, as well as to provide a higher level
interface to define complex transformations like quantization.
Since there are many variants of RNN APIs in TensorFlow, our approach has been
twofold:
1. Provide **native support for standard TensorFlow RNN APIs** like Keras LSTM.
This is the recommended option.
1.  Provide an **interface into the conversion infrastructure for user-defined
    RNN implementations** to plug in and get converted to TensorFlow Lite. We
    provide a couple of out-of-the-box examples of such conversion using
    lingvo's
    [LSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L123)
    and
    [LayerNormalizedLSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L519)
    RNN interfaces.
## Converter API
Currently this feature is available through the
[tf-nightly](https://pypi.org/project/tf-nightly/) pip package or from head. It
will be available in the TensorFlow 2.3 release.
This conversion functionality is available when converting to TensorFlow Lite
via a SavedModel or from the Keras model directly. See the example usages
below.
### From saved model
```
import tensorflow as tf

# Build a saved model. Here concrete_func is the exported function
# corresponding to the TensorFlow model containing one or more
# Keras LSTM layers.
saved_model, saved_model_dir = build_saved_model_lstm(...)
saved_model.save(saved_model_dir, save_format="tf", signatures=concrete_func)

# Convert the model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
```
### From Keras model
```
import tensorflow as tf

# Build a Keras model containing one or more Keras LSTM layers.
keras_model = build_keras_lstm(...)

# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()
```
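For reference, here is a minimal sketch of what a helper like
`build_keras_lstm` could look like; the helper name, layer sizes, and fixed
batch size are illustrative assumptions, not part of the converter API.

```
import tensorflow as tf

# Hypothetical helper for the snippet above: a functional Keras model
# containing a single LSTM layer. Fully specified input shapes (including
# the batch size) are assumed here for illustration.
def build_keras_lstm(units=20, timesteps=28, features=28):
  inputs = tf.keras.layers.Input(shape=(timesteps, features), batch_size=1)
  outputs = tf.keras.layers.LSTM(units, return_sequences=True)(inputs)
  return tf.keras.Model(inputs, outputs)
```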
## Example
The Keras LSTM to TensorFlow Lite
[Colab](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/experimental_new_converter/Keras_LSTM_fusion_Codelab.ipynb)
illustrates the end-to-end usage with the TensorFlow Lite interpreter.
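As a minimal sketch of that end-to-end flow, inference with the
`tflite_model` bytes produced by the snippets above looks roughly like this:

```
import numpy as np
import tensorflow as tf

# Load the converted model into the TensorFlow Lite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a random input with the model's expected shape and read the result.
sample = np.random.rand(*input_details[0]['shape']).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]['index'])
```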
## TensorFlow RNN APIs supported
### Keras LSTM conversion (recommended)
We support out-of-the-box conversion of Keras LSTM to TensorFlow Lite. For
details on how this works, please refer to the
[Keras LSTM interface](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/experimental_new_converter/Keras_LSTM_fusion_Codelab.ipynb)
and to the conversion logic
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L627).
It is also important to highlight TensorFlow Lite's LSTM contract with respect
to the Keras operation definition (illustrated in the sketch after the
following list):
1. The dimension 0 of the input tensor is the batch size.
1. The dimension 0 of the recurrent\_weight tensor is the number of outputs.
1. The **weight** and **recurrent\_kernel** tensors are transposed.
1. The transposed weight, transposed recurrent\_kernel and bias tensors are
split into 4 equal sized tensors along the dimension 0. These correspond to
**input gate, forget gate, cell, and output gate**.
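As an illustration of this contract (not converter code), the following sketch
rearranges a Keras LSTM kernel with numpy; the shapes are assumptions for the
example:

```
import numpy as np

# Keras stores the LSTM kernel as (features, 4 * units), with the four
# gates concatenated along the last dimension.
units, features = 20, 28
kernel = np.random.rand(features, 4 * units).astype(np.float32)

# Contract item 3: the weight tensor is transposed.
transposed = kernel.T  # shape (4 * units, features)

# Contract item 4: split into 4 equal sized tensors along dimension 0,
# corresponding to the input gate, forget gate, cell, and output gate.
w_input, w_forget, w_cell, w_output = np.split(transposed, 4, axis=0)
assert w_input.shape == (units, features)
```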
#### Keras LSTM Variants
##### Time major
Users may choose time-major or non-time-major inputs. Keras LSTM adds a
time-major attribute to the function def attributes. For unidirectional
sequence LSTM, we can simply map it to unidirectional\_sequence\_lstm's
[time major attribute](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3508).
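For example, a time-major Keras LSTM can be declared as below; this is a
sketch, and in TF 2.x the `time_major` argument is part of
`tf.keras.layers.LSTM`:

```
import tensorflow as tf

# With time_major=True the layer consumes inputs shaped
# [timesteps, batch, features] instead of [batch, timesteps, features];
# the converter maps this onto the fused op's time major attribute.
lstm = tf.keras.layers.LSTM(20, time_major=True, return_sequences=True)
```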
##### Bidirectional LSTM
Bidirectional LSTM can be implemented with two Keras LSTM layers, one forward
and one backward; see examples
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/wrappers.py#L381).
Once we see the go\_backwards attribute, we recognize the layer as a backward
LSTM and can then group the forward & backward LSTMs together. **This is
future work.** Currently, conversion creates two UnidirectionalSequenceLSTM
operators in the TensorFlow Lite model.
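A sketch of such a model using the Keras `Bidirectional` wrapper, with layer
sizes assumed for illustration:

```
import tensorflow as tf

# The backward layer carries go_backwards=True, which is the attribute the
# conversion logic looks for. Today this produces two
# UnidirectionalSequenceLSTM operators in the TensorFlow Lite model.
forward_layer = tf.keras.layers.LSTM(20, return_sequences=True)
backward_layer = tf.keras.layers.LSTM(20, return_sequences=True,
                                      go_backwards=True)
bilstm = tf.keras.layers.Bidirectional(forward_layer,
                                       backward_layer=backward_layer)
```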
### User-defined LSTM conversion examples
TensorFlow Lite also provides a way to convert user-defined LSTM
implementations. Here we use Lingvo's LSTM as an example of how that can be
done. For details, please refer to the
[lingvo.LSTMCellSimple interface](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L230)
and the conversion logic
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L123).
We also provide an example for another of Lingvo's LSTM definitions in the
[lingvo.LayerNormalizedLSTMCellSimple interface](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L1179)
and its conversion logic
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L130).
## “Bring your own TensorFlow RNN” to TensorFlow Lite
If a user's RNN interface is different from the standard supported ones, there
are a couple of options:
**Option 1:** Write adapter code in TensorFlow python to adapt the RNN
interface to the Keras RNN interface. This means writing a tf.function with a
[tf\_implements annotation](https://github.com/tensorflow/community/pull/113)
on the generated RNN interface's function that is identical to the one
generated by the Keras LSTM layer. After this, the same conversion API used
for Keras LSTM will work (a sketch follows below).
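The sketch below illustrates Option 1. The implements string and the wrapped
`my_custom_lstm` function are placeholders, not the real Keras LSTM
annotation, and the argument list must match whatever the conversion code
expects:

```
import tensorflow as tf

# Hypothetical adapter: expose a user-defined RNN through a tf.function
# whose interface mirrors the Keras LSTM one. The implements string and
# signature below are illustrative assumptions.
@tf.function(experimental_implements="my_project.MyLSTM")
def my_lstm_adapter(inputs, init_h, init_c):
  # my_custom_lstm is the user's RNN implementation (assumed to exist).
  return my_custom_lstm(inputs, init_h, init_c)
```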
**Option 2:** If the above is not possible (e.g. the Keras LSTM is missing
some functionality that is currently exposed by TensorFlow Lite's fused LSTM
op, like layer normalization), then extend the TensorFlow Lite converter by
writing custom conversion code and plugging it into the
prepare-composite-functions MLIR pass
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L108).
The function's interface should be treated like an API contract and should
contain the arguments needed to convert to fused TensorFlow Lite LSTM
operators, i.e. input, bias, weights, projection, layer normalization, etc. It
is preferable for the tensors passed as arguments to this function to have
known rank (i.e. RankedTensorType in MLIR). This makes it much easier to write
conversion code that can assume these tensors are RankedTensorType and helps
transform them to ranked tensors corresponding to the fused TensorFlow Lite
operator's operands.
A complete example of such a conversion flow is Lingvo's LSTMCellSimple to
TensorFlow Lite conversion.
The LSTMCellSimple in Lingvo is defined
[here](https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L230).
Models trained with this LSTM cell can be converted to TensorFlow Lite as
follows:
1.  Wrap all uses of LSTMCellSimple in a tf.function with a tf\_implements
    annotation that is labelled as such (e.g. lingvo.LSTMCellSimple would be a
    good annotation name here), as shown in the sketch after this list. Make
    sure the tf.function that is generated matches the interface of the
    function expected in the conversion code. This is a contract between the
    model author adding the annotation and the conversion code.
1.  Extend the prepare-composite-functions pass to plug in the conversion from
    this custom composite op to the fused TensorFlow Lite LSTM op. See the
    [LSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L123)
    conversion code.
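A sketch of step 1, using the suggested annotation name; the wrapper's
signature and the `make_lstm_cell_simple` helper are illustrative assumptions,
and the real argument list is whatever contract the conversion code defines:

```
import tensorflow as tf

# Hypothetical wrapper around uses of LSTMCellSimple. The signature here
# is illustrative; the actual contract is defined by the conversion code
# in prepare_composite_functions_tf.cc.
@tf.function(experimental_implements="lingvo.LSTMCellSimple")
def lstm_cell_simple(inputs, state_m, state_c):
  cell = make_lstm_cell_simple()  # user code creating the Lingvo cell
  return cell(inputs, state_m, state_c)
```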
The conversion contract:
1. **Weight** and **projection** tensors are transposed.
1. The **{input, recurrent}** to **{cell, input gate, forget gate, output
gate}** are extracted by slicing the transposed weight tensor.
1. The **{bias}** to **{cell, input gate, forget gate, output gate}** are
extracted by slicing the bias tensor.
1. The **projection** is extracted by slicing the transposed projection tensor.
1.  A similar conversion is written for
    [LayerNormalizedLSTMCellSimple](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L519).
1.  The rest of the TensorFlow Lite conversion infrastructure, including all
    the
    [MLIR passes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc#L58)
    defined, as well as the final export to the TensorFlow Lite flatbuffer,
    can be reused.
## Known issues/limitations
1. Currently there is support only for converting stateless Keras LSTM (default
behavior in Keras). Stateful Keras LSTM conversion is future work.
1. It is still possible to model a stateful Keras LSTM layer using the
underlying stateless Keras LSTM layer and managing the state explicitly in
the user program. Such a TensorFlow program can still be converted to
TensorFlow Lite using the feature being described here.
1. Bidirectional LSTM is currently modelled as two UnidirectionalSequenceLSTM
operators in TensorFlow Lite. This will be replaced with a single
BidirectionalSequenceLSTM op.