Update examples for converting TFLite models in 2.0.

PiperOrigin-RevId: 241822336
This commit is contained in:
Nupur Garg 2019-04-03 15:51:09 -07:00 committed by TensorFlower Gardener
parent 09deaeb03c
commit 72222a96ac

View File

@ -46,9 +46,6 @@ tflite_model = converter.convert()
The following example shows how to convert a SavedModel into a TensorFlow Lite The following example shows how to convert a SavedModel into a TensorFlow Lite
`FlatBuffer`. `FlatBuffer`.
Note: Due to a known issue with preserving input shapes with SavedModels,
`set_shape` needs to be called for all input tensors.
```python ```python
import tensorflow as tf import tensorflow as tf
@ -69,9 +66,6 @@ model = tf.saved_model.load(export_dir)
concrete_func = model.signatures[ concrete_func = model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
# Set the shape manually.
concrete_func.inputs[0].set_shape(input_data.shape)
# Convert the model. # Convert the model.
converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func) converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
tflite_model = converter.convert() tflite_model = converter.convert()
@ -107,8 +101,9 @@ tflite_model = converter.convert()
### End-to-end MobileNet conversion <a name="mobilenet"></a> ### End-to-end MobileNet conversion <a name="mobilenet"></a>
The following example shows how to convert and run inference on a pre-trained The following example shows how to convert and run inference on a pre-trained
`tf.Keras` MobileNet model to TensorFlow Lite. In order to load the model from `tf.Keras` MobileNet model to TensorFlow Lite. It compares the results of the
file, use `model_path` instead of `model_content`. TensorFlow and TensorFlow Lite model on random data. In order to load the model
from file, use `model_path` instead of `model_content`.
```python ```python
import numpy as np import numpy as np
@ -118,15 +113,10 @@ import tensorflow as tf
model = tf.keras.applications.MobileNetV2( model = tf.keras.applications.MobileNetV2(
weights="imagenet", input_shape=(224, 224, 3)) weights="imagenet", input_shape=(224, 224, 3))
# Save and load the model to generate the concrete function to export. # Create a concrete function to export.
export_dir = "/tmp/test_model/mobilenet" to_save = tf.function(lambda x: model(x))
tf.saved_model.save(model, export_dir) concrete_func = to_save.get_concrete_function(
model = tf.saved_model.load(export_dir) tf.TensorSpec([1, 224, 224, 3], tf.float32))
concrete_func = model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
# Set the shape manually.
concrete_func.inputs[0].set_shape([1, 224, 224, 3])
# Convert the model. # Convert the model.
converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func) converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
@ -140,14 +130,20 @@ interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details() output_details = interpreter.get_output_details()
# Test model on random input data. # Test the TensorFlow Lite model on random input data.
input_shape = input_details[0]['shape'] input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data) interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke() interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index']) tflite_results = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
# Test the TensorFlow model on random input data.
tf_results = concrete_func(tf.constant(input_data))
# Compare the result.
for tf_result, tflite_result in zip(tf_results, tflite_results):
np.testing.assert_almost_equal(tf_result, tflite_result, decimal=5)
``` ```
## Summary of changes in `TFLiteConverter` between 1.X and 2.0 <a name="differences"></a> ## Summary of changes in `TFLiteConverter` between 1.X and 2.0 <a name="differences"></a>
@ -161,7 +157,7 @@ The following section summarizes the changes in `TFLiteConverter` from 1.X to
`TFLiteConverter` in 2.0 supports SavedModels and Keras model files generated in `TFLiteConverter` in 2.0 supports SavedModels and Keras model files generated in
both 1.X and 2.0. However, the conversion process no longer supports frozen both 1.X and 2.0. However, the conversion process no longer supports frozen
`GraphDefs` generated in 1.X. Users who want to convert frozen `GraphDefs` to `GraphDefs` generated in 1.X. Users who want to convert frozen `GraphDefs` to
TensorFlow Lite should use `tensorflow.compat.v1`. TensorFlow Lite should use `tf.compat.v1.TFLiteConverter`.
### Quantization-aware training ### Quantization-aware training
@ -184,7 +180,7 @@ API is being reworked and streamlined in a direction that supports
quantization-aware training through the Keras API. These attributes will be quantization-aware training through the Keras API. These attributes will be
removed in the 2.0 API until the new quantization API is launched. Users who removed in the 2.0 API until the new quantization API is launched. Users who
want to convert models generated by the rewriter function can use want to convert models generated by the rewriter function can use
`tensorflow.compat.v1`. `tf.compat.v1.TFLiteConverter`.
### Changes to attributes ### Changes to attributes
@ -234,15 +230,6 @@ import tensorflow.compat.v2 as tf
tf.enable_v2_behavior() tf.enable_v2_behavior()
``` ```
### Using TensorFlow 1.X from a 2.0 installation <a name="use-1.X-from-2.0"></a>
TensorFlow 1.X can be enabled from 2.0 installation. This can be useful if you
are using features that are no longer supported in 2.0.
```python
import tensorflow.compat.v1 as tf
```
### Build from source code <a name="latest_package"></a> ### Build from source code <a name="latest_package"></a>
In order to run the latest version of the TensorFlow Lite Converter Python API, In order to run the latest version of the TensorFlow Lite Converter Python API,