Update examples for converting TFLite models in 2.0.
PiperOrigin-RevId: 241822336
This commit is contained in:
parent
09deaeb03c
commit
72222a96ac
@ -46,9 +46,6 @@ tflite_model = converter.convert()
|
|||||||
The following example shows how to convert a SavedModel into a TensorFlow Lite
|
The following example shows how to convert a SavedModel into a TensorFlow Lite
|
||||||
`FlatBuffer`.
|
`FlatBuffer`.
|
||||||
|
|
||||||
Note: Due to a known issue with preserving input shapes with SavedModels,
|
|
||||||
`set_shape` needs to be called for all input tensors.
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@ -69,9 +66,6 @@ model = tf.saved_model.load(export_dir)
|
|||||||
concrete_func = model.signatures[
|
concrete_func = model.signatures[
|
||||||
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
||||||
|
|
||||||
# Set the shape manually.
|
|
||||||
concrete_func.inputs[0].set_shape(input_data.shape)
|
|
||||||
|
|
||||||
# Convert the model.
|
# Convert the model.
|
||||||
converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
|
converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
@ -107,8 +101,9 @@ tflite_model = converter.convert()
|
|||||||
### End-to-end MobileNet conversion <a name="mobilenet"></a>
|
### End-to-end MobileNet conversion <a name="mobilenet"></a>
|
||||||
|
|
||||||
The following example shows how to convert and run inference on a pre-trained
|
The following example shows how to convert and run inference on a pre-trained
|
||||||
`tf.Keras` MobileNet model to TensorFlow Lite. In order to load the model from
|
`tf.Keras` MobileNet model to TensorFlow Lite. It compares the results of the
|
||||||
file, use `model_path` instead of `model_content`.
|
TensorFlow and TensorFlow Lite model on random data. In order to load the model
|
||||||
|
from file, use `model_path` instead of `model_content`.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -118,15 +113,10 @@ import tensorflow as tf
|
|||||||
model = tf.keras.applications.MobileNetV2(
|
model = tf.keras.applications.MobileNetV2(
|
||||||
weights="imagenet", input_shape=(224, 224, 3))
|
weights="imagenet", input_shape=(224, 224, 3))
|
||||||
|
|
||||||
# Save and load the model to generate the concrete function to export.
|
# Create a concrete function to export.
|
||||||
export_dir = "/tmp/test_model/mobilenet"
|
to_save = tf.function(lambda x: model(x))
|
||||||
tf.saved_model.save(model, export_dir)
|
concrete_func = to_save.get_concrete_function(
|
||||||
model = tf.saved_model.load(export_dir)
|
tf.TensorSpec([1, 224, 224, 3], tf.float32))
|
||||||
concrete_func = model.signatures[
|
|
||||||
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
|
|
||||||
|
|
||||||
# Set the shape manually.
|
|
||||||
concrete_func.inputs[0].set_shape([1, 224, 224, 3])
|
|
||||||
|
|
||||||
# Convert the model.
|
# Convert the model.
|
||||||
converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
|
converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
|
||||||
@ -140,14 +130,20 @@ interpreter.allocate_tensors()
|
|||||||
input_details = interpreter.get_input_details()
|
input_details = interpreter.get_input_details()
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
|
|
||||||
# Test model on random input data.
|
# Test the TensorFlow Lite model on random input data.
|
||||||
input_shape = input_details[0]['shape']
|
input_shape = input_details[0]['shape']
|
||||||
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
|
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
|
||||||
interpreter.set_tensor(input_details[0]['index'], input_data)
|
interpreter.set_tensor(input_details[0]['index'], input_data)
|
||||||
|
|
||||||
interpreter.invoke()
|
interpreter.invoke()
|
||||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
tflite_results = interpreter.get_tensor(output_details[0]['index'])
|
||||||
print(output_data)
|
|
||||||
|
# Test the TensorFlow model on random input data.
|
||||||
|
tf_results = concrete_func(tf.constant(input_data))
|
||||||
|
|
||||||
|
# Compare the result.
|
||||||
|
for tf_result, tflite_result in zip(tf_results, tflite_results):
|
||||||
|
np.testing.assert_almost_equal(tf_result, tflite_result, decimal=5)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Summary of changes in `TFLiteConverter` between 1.X and 2.0 <a name="differences"></a>
|
## Summary of changes in `TFLiteConverter` between 1.X and 2.0 <a name="differences"></a>
|
||||||
@ -161,7 +157,7 @@ The following section summarizes the changes in `TFLiteConverter` from 1.X to
|
|||||||
`TFLiteConverter` in 2.0 supports SavedModels and Keras model files generated in
|
`TFLiteConverter` in 2.0 supports SavedModels and Keras model files generated in
|
||||||
both 1.X and 2.0. However, the conversion process no longer supports frozen
|
both 1.X and 2.0. However, the conversion process no longer supports frozen
|
||||||
`GraphDefs` generated in 1.X. Users who want to convert frozen `GraphDefs` to
|
`GraphDefs` generated in 1.X. Users who want to convert frozen `GraphDefs` to
|
||||||
TensorFlow Lite should use `tensorflow.compat.v1`.
|
TensorFlow Lite should use `tf.compat.v1.TFLiteConverter`.
|
||||||
|
|
||||||
### Quantization-aware training
|
### Quantization-aware training
|
||||||
|
|
||||||
@ -184,7 +180,7 @@ API is being reworked and streamlined in a direction that supports
|
|||||||
quantization-aware training through the Keras API. These attributes will be
|
quantization-aware training through the Keras API. These attributes will be
|
||||||
removed in the 2.0 API until the new quantization API is launched. Users who
|
removed in the 2.0 API until the new quantization API is launched. Users who
|
||||||
want to convert models generated by the rewriter function can use
|
want to convert models generated by the rewriter function can use
|
||||||
`tensorflow.compat.v1`.
|
`tf.compat.v1.TFLiteConverter`.
|
||||||
|
|
||||||
### Changes to attributes
|
### Changes to attributes
|
||||||
|
|
||||||
@ -234,15 +230,6 @@ import tensorflow.compat.v2 as tf
|
|||||||
tf.enable_v2_behavior()
|
tf.enable_v2_behavior()
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using TensorFlow 1.X from a 2.0 installation <a name="use-1.X-from-2.0"></a>
|
|
||||||
|
|
||||||
TensorFlow 1.X can be enabled from 2.0 installation. This can be useful if you
|
|
||||||
are using features that are no longer supported in 2.0.
|
|
||||||
|
|
||||||
```python
|
|
||||||
import tensorflow.compat.v1 as tf
|
|
||||||
```
|
|
||||||
|
|
||||||
### Build from source code <a name="latest_package"></a>
|
### Build from source code <a name="latest_package"></a>
|
||||||
|
|
||||||
In order to run the latest version of the TensorFlow Lite Converter Python API,
|
In order to run the latest version of the TensorFlow Lite Converter Python API,
|
||||||
|
Loading…
Reference in New Issue
Block a user