Update examples for converting TFLite models in 2.0.
PiperOrigin-RevId: 241822336
@@ -46,9 +46,6 @@ tflite_model = converter.convert()
 The following example shows how to convert a SavedModel into a TensorFlow Lite
 `FlatBuffer`.
 
-Note: Due to a known issue with preserving input shapes with SavedModels,
-`set_shape` needs to be called for all input tensors.
-
 ```python
 import tensorflow as tf
 
@@ -69,9 +66,6 @@ model = tf.saved_model.load(export_dir)
 concrete_func = model.signatures[
   tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
 
-# Set the shape manually.
-concrete_func.inputs[0].set_shape(input_data.shape)
-
 # Convert the model.
 converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
 tflite_model = converter.convert()
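Taken together, the two hunks above leave the SavedModel example free of the `set_shape` workaround. As a sanity check, here is a minimal end-to-end sketch of the resulting flow, assuming the 2.0 alpha API used in the doc (`tf.lite.TFLiteConverter.from_concrete_function`); the trivial model and the `/tmp/saved_model` path are illustrative only:

```python
import tensorflow as tf

# Illustrative stand-in model: a tf.function with one float input.
root = tf.train.Checkpoint()
root.f = tf.function(lambda x: 2. * x)
export_dir = "/tmp/saved_model"  # placeholder path
tf.saved_model.save(
    root, export_dir,
    signatures=root.f.get_concrete_function(tf.TensorSpec([], tf.float32)))

# Load the SavedModel and convert its serving signature, as the doc now shows.
model = tf.saved_model.load(export_dir)
concrete_func = model.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
tflite_model = converter.convert()
```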
@@ -107,8 +101,9 @@ tflite_model = converter.convert()
 ### End-to-end MobileNet conversion <a name="mobilenet"></a>
 
 The following example shows how to convert and run inference on a pre-trained
-`tf.Keras` MobileNet model to TensorFlow Lite. In order to load the model from
-file, use `model_path` instead of `model_content`.
+`tf.Keras` MobileNet model to TensorFlow Lite. It compares the results of the
+TensorFlow and TensorFlow Lite model on random data. In order to load the model
+from file, use `model_path` instead of `model_content`.
 
 ```python
 import numpy as np
@@ -118,15 +113,10 @@ import tensorflow as tf
 model = tf.keras.applications.MobileNetV2(
     weights="imagenet", input_shape=(224, 224, 3))
 
-# Save and load the model to generate the concrete function to export.
-export_dir = "/tmp/test_model/mobilenet"
-tf.saved_model.save(model, export_dir)
-model = tf.saved_model.load(export_dir)
-concrete_func = model.signatures[
-  tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
-
-# Set the shape manually.
-concrete_func.inputs[0].set_shape([1, 224, 224, 3])
+# Create a concrete function to export.
+to_save = tf.function(lambda x: model(x))
+concrete_func = to_save.get_concrete_function(
+    tf.TensorSpec([1, 224, 224, 3], tf.float32))
 
 # Convert the model.
 converter = tf.lite.TFLiteConverter.from_concrete_function(concrete_func)
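The `tf.function` wrapper introduced here is what pins the batch-1 input shape that the deleted `set_shape` call used to fix by hand. A quick, hedged sketch of checking that on the same MobileNet model (reusing the names from the hunk):

```python
import tensorflow as tf

model = tf.keras.applications.MobileNetV2(
    weights="imagenet", input_shape=(224, 224, 3))

# Trace the model at a fixed input signature; the concrete function
# then carries a static [1, 224, 224, 3] shape without any set_shape call.
to_save = tf.function(lambda x: model(x))
concrete_func = to_save.get_concrete_function(
    tf.TensorSpec([1, 224, 224, 3], tf.float32))
print(concrete_func.inputs[0].shape)  # (1, 224, 224, 3)
```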
@@ -140,14 +130,20 @@ interpreter.allocate_tensors()
 input_details = interpreter.get_input_details()
 output_details = interpreter.get_output_details()
 
-# Test model on random input data.
+# Test the TensorFlow Lite model on random input data.
 input_shape = input_details[0]['shape']
 input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
 interpreter.set_tensor(input_details[0]['index'], input_data)
 
 interpreter.invoke()
-output_data = interpreter.get_tensor(output_details[0]['index'])
-print(output_data)
+tflite_results = interpreter.get_tensor(output_details[0]['index'])
+
+# Test the TensorFlow model on random input data.
+tf_results = concrete_func(tf.constant(input_data))
+
+# Compare the result.
+for tf_result, tflite_result in zip(tf_results, tflite_results):
+  np.testing.assert_almost_equal(tf_result, tflite_result, decimal=5)
 ```
 
 ## Summary of changes in `TFLiteConverter` between 1.X and 2.0 <a name="differences"></a>
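The comparison loop added in this hunk generalizes to any converted model. A sketch of the same check factored into helpers (the names `run_tflite` and `compare_tf_tflite` are illustrative, not part of the docs):

```python
import numpy as np
import tensorflow as tf

def run_tflite(tflite_model, input_data):
  # Run one inference through the TFLite interpreter.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  interpreter.set_tensor(input_details[0]['index'], input_data)
  interpreter.invoke()
  return interpreter.get_tensor(output_details[0]['index'])

def compare_tf_tflite(concrete_func, tflite_model, input_data, decimal=5):
  # Assert that TF and TFLite agree on the same input.
  tf_results = concrete_func(tf.constant(input_data))
  tflite_results = run_tflite(tflite_model, input_data)
  for tf_result, tflite_result in zip(tf_results, tflite_results):
    np.testing.assert_almost_equal(tf_result, tflite_result, decimal=decimal)
```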
@@ -161,7 +157,7 @@ The following section summarizes the changes in `TFLiteConverter` from 1.X to
 `TFLiteConverter` in 2.0 supports SavedModels and Keras model files generated in
 both 1.X and 2.0. However, the conversion process no longer supports frozen
 `GraphDefs` generated in 1.X. Users who want to convert frozen `GraphDefs` to
-TensorFlow Lite should use `tensorflow.compat.v1`.
+TensorFlow Lite should use `tf.compat.v1.TFLiteConverter`.
 
 ### Quantization-aware training
 
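For reference, a minimal sketch of the `tf.compat.v1.TFLiteConverter` fallback this hunk now points to; the graph file name and tensor names are placeholders, not from the docs:

```python
import tensorflow.compat.v1 as tf

# Convert a 1.X frozen GraphDef with the compat-mode converter.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="frozen_graph.pb",   # placeholder path
    input_arrays=["input"],             # placeholder tensor names
    output_arrays=["output"])
tflite_model = converter.convert()
```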
@@ -184,7 +180,7 @@ API is being reworked and streamlined in a direction that supports
 quantization-aware training through the Keras API. These attributes will be
 removed in the 2.0 API until the new quantization API is launched. Users who
 want to convert models generated by the rewriter function can use
-`tensorflow.compat.v1`.
+`tf.compat.v1.TFLiteConverter`.
 
 ### Changes to attributes
 
@@ -234,15 +230,6 @@ import tensorflow.compat.v2 as tf
 tf.enable_v2_behavior()
 ```
 
-### Using TensorFlow 1.X from a 2.0 installation <a name="use-1.X-from-2.0"></a>
-
-TensorFlow 1.X can be enabled from 2.0 installation. This can be useful if you
-are using features that are no longer supported in 2.0.
-
-```python
-import tensorflow.compat.v1 as tf
-```
-
 ### Build from source code <a name="latest_package"></a>
 
 In order to run the latest version of the TensorFlow Lite Converter Python API,