diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml
index df004b12680..c74057e193c 100644
--- a/tensorflow/lite/g3doc/_book.yaml
+++ b/tensorflow/lite/g3doc/_book.yaml
@@ -77,9 +77,11 @@ upper_tabs:
- title: "Post-training quantization"
path: /lite/performance/post_training_quantization
- title: "Post-training quantization example"
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_quant.ipynb
- title: "Post-training integer quantization example"
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_integer_quant.ipynb
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
+ - title: "Post-training float16 quantization example"
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
status: external
- title: "Delegates"
path: /lite/performance/delegates
diff --git a/tensorflow/lite/g3doc/guide/get_started.md b/tensorflow/lite/g3doc/guide/get_started.md
index a8f5daae9df..ce16b795ec9 100644
--- a/tensorflow/lite/g3doc/guide/get_started.md
+++ b/tensorflow/lite/g3doc/guide/get_started.md
@@ -272,11 +272,16 @@ following Python code quantizes a `SavedModel` and saves it to disk:
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_quantized_model)
```
+TensorFlow Lite supports reducing the precision of values from full floating
+point to half-precision floats (float16) or 8-bit integers. Each choice
+involves a trade-off between model size and accuracy, and some operations have
+optimized implementations for these reduced-precision types.
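+
+For example, float16 weights can be requested by also listing float16 in the
+converter's supported types. This is a minimal sketch (the output filename is
+illustrative); see the post-training quantization guide below for details:
+
+```
+import tensorflow as tf
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+# Request float16 weights; omitting this line (as above) quantizes
+# weights to 8-bit integers instead.
+converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]
+tflite_fp16_model = converter.convert()
+open("converted_model_fp16.tflite", "wb").write(tflite_fp16_model)
+```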
+
To learn more about quantization, see
[Post-training quantization](../performance/post_training_quantization.md).
diff --git a/tensorflow/lite/g3doc/performance/images/optimization.jpg b/tensorflow/lite/g3doc/performance/images/optimization.jpg
index 1a419f607d6..f866768509d 100644
Binary files a/tensorflow/lite/g3doc/performance/images/optimization.jpg and b/tensorflow/lite/g3doc/performance/images/optimization.jpg differ
diff --git a/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
new file mode 100644
index 00000000000..22246f2edee
--- /dev/null
+++ b/tensorflow/lite/g3doc/performance/post_training_float16_quant.ipynb
@@ -0,0 +1,647 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "post-training-fp16-quant.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "private_outputs": true,
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "6Y8E0lw5eYWm"
+ },
+ "source": [
+        "# Post-training float16 quantization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "CIGrZZPTZVeO"
+ },
+ "source": [
+        ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "BTC1rDAuei_1"
+ },
+ "source": [
+ "## Overview\n",
+ "\n",
+ "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
+        "converting weights to 16-bit floating point values during model conversion from TensorFlow to TensorFlow Lite's FlatBuffer format. This results in a 2x reduction in model size. Some hardware, like GPUs, can compute natively in this reduced-precision arithmetic, realizing a speedup over traditional floating point execution. The TensorFlow Lite GPU delegate can be configured to run in this way. However, a model converted to float16 weights can still run on the CPU without additional modification: the float16 weights are upsampled to float32 prior to the first inference. This permits a significant reduction in model size in exchange for a minimal impact on latency and accuracy.\n",
+ "\n",
+        "In this tutorial, we train an MNIST model from scratch, check its accuracy in TensorFlow, and then convert the saved model into a TensorFlow Lite FlatBuffer\n",
+        "with float16 quantization. Finally, we check the\n",
+        "accuracy of the converted model and compare it to the original saved model. We\n",
+        "run the training script [mnist.py](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py) from the\n",
+        "[TensorFlow official MNIST tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2XsEP17Zelz9"
+ },
+ "source": [
+ "## Building an MNIST model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "dDqqUIZjZjac"
+ },
+ "source": [
+ "### Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "gyqAw1M9lyab",
+ "colab": {}
+ },
+ "source": [
+ "! pip uninstall -y tensorflow\n",
+ "! pip install -U tf-nightly"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "WsN6s5L1ieNl",
+ "colab": {}
+ },
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "00U0taBoe-w7",
+ "colab": {}
+ },
+ "source": [
+ "! git clone --depth 1 https://github.com/tensorflow/models"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "c6nb7OPlXs_3",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+        "# Confirm that the float16 type constant is available in this TensorFlow build.\n",
+        "tf.lite.constants.FLOAT16"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "4XZPtSh-fUOc",
+ "colab": {}
+ },
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "\n",
+ "if sys.version_info.major >= 3:\n",
+ " import pathlib\n",
+ "else:\n",
+ " import pathlib2 as pathlib\n",
+ "\n",
+ "# Add `models` to the python path.\n",
+ "models_path = os.path.join(os.getcwd(), \"models\")\n",
+ "sys.path.append(models_path)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "eQ6Q0qqKZogR"
+ },
+ "source": [
+ "### Train and export the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "eMsw_6HujaqM",
+ "colab": {}
+ },
+ "source": [
+ "saved_models_root = \"/tmp/mnist_saved_model\""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "hWSAjQWagIHl",
+ "colab": {}
+ },
+ "source": [
+        "# The above path addition is not visible to subprocesses, so add the path for the subprocess here as well.\n",
+ "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "5NMaNZQCkW9X"
+ },
+ "source": [
+        "For this example, we trained the model for only a single epoch, so it only reaches ~96% accuracy.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xl8_fzVAZwOh"
+ },
+ "source": [
+ "### Convert to a TensorFlow Lite model\n",
+ "\n",
+        "The exported `SavedModel` directory is named with a timestamp. Select the most recent one:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "Xp5oClaZkbtn",
+ "colab": {}
+ },
+ "source": [
+ "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+ "saved_model_dir"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "AT8BgkKmljOy"
+ },
+ "source": [
+ "Using the [Python `TFLiteConverter`](https://www.tensorflow.org/lite/convert/python_api), the saved model can be converted into a TensorFlow Lite model.\n",
+ "\n",
+ "First load the model using the `TFLiteConverter`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "_i8B2nDZmAgQ",
+ "colab": {}
+ },
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
+ "\n",
+ "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
+ "tflite_model = converter.convert()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "F2o2ZfF0aiCx"
+ },
+ "source": [
+ "Write it out to a `.tflite` file:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "vptWZq2xnclo",
+ "colab": {}
+ },
+ "source": [
+ "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+ "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "Ie9pQaQrn5ue",
+ "colab": {}
+ },
+ "source": [
+ "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+ "tflite_model_file.write_bytes(tflite_model)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7BONhYtYocQY"
+ },
+ "source": [
+ "To instead quantize the model to float16 on export, first set the `optimizations` flag to use default optimizations. Then specify that float16 is the supported type on the target platform:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "HEZ6ET1AHAS3",
+ "colab": {}
+ },
+ "source": [
+ "tf.logging.set_verbosity(tf.logging.INFO)\n",
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
+ "converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xW84iMYjHd9t",
+ "colab_type": "text"
+ },
+ "source": [
+        "Finally, convert the model as usual. Note that, by default, the converted model still uses float inputs and outputs for invocation convenience."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "yuNfl3CoHNK3",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "tflite_fp16_model = converter.convert()\n",
+ "tflite_model_fp16_file = tflite_models_dir/\"mnist_model_quant_f16.tflite\"\n",
+ "tflite_model_fp16_file.write_bytes(tflite_fp16_model)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "PhMmUTl4sbkz"
+ },
+ "source": [
+ "Note how the resulting file is approximately `1/2` the size."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "JExfcfLDscu4",
+ "colab": {}
+ },
+ "source": [
+ "!ls -lh {tflite_models_dir}"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L8lQHMp_asCq"
+ },
+ "source": [
+ "## Run the TensorFlow Lite models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-5l6-ciItvX6"
+ },
+ "source": [
+ "We can run the TensorFlow Lite model using the Python TensorFlow Lite\n",
+ "Interpreter. \n",
+ "\n",
+ "### Load the test data\n",
+ "\n",
+ "First, let's load the MNIST test data to feed to the model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "eTIuU07NuKFL",
+ "colab": {}
+ },
+ "source": [
+ "import numpy as np\n",
+ "_, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+ "images, labels = tf.cast(mnist_test[0], tf.float32)/255.0, mnist_test[1]\n",
+ "\n",
+ "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Ap_jE7QRvhPf"
+ },
+ "source": [
+        "### Load the models into the interpreters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "Jn16Rc23zTss",
+ "colab": {}
+ },
+ "source": [
+ "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n",
+ "interpreter.allocate_tensors()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "J8Pztk1mvNVL",
+ "colab": {}
+ },
+ "source": [
+ "interpreter_fp16 = tf.lite.Interpreter(model_path=str(tflite_model_fp16_file))\n",
+ "interpreter_fp16.allocate_tensors()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2opUt_JTdyEu"
+ },
+ "source": [
+ "### Test the models on one image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "AKslvo2kwWac",
+ "colab": {}
+ },
+ "source": [
+ "for img, label in mnist_ds:\n",
+ " break\n",
+ "\n",
+ "interpreter.set_tensor(interpreter.get_input_details()[0][\"index\"], img)\n",
+ "interpreter.invoke()\n",
+ "predictions = interpreter.get_tensor(\n",
+ " interpreter.get_output_details()[0][\"index\"])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "XZClM2vo3_bm",
+ "colab": {}
+ },
+ "source": [
+ "import matplotlib.pylab as plt\n",
+ "\n",
+ "plt.imshow(img[0])\n",
+ "template = \"True:{true}, predicted:{predict}\"\n",
+ "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+ " predict=str(predictions[0])))\n",
+ "plt.grid(False)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "3gwhv4lKbYZ4",
+ "colab": {}
+ },
+ "source": [
+ "interpreter_fp16.set_tensor(\n",
+ " interpreter_fp16.get_input_details()[0][\"index\"], img)\n",
+ "interpreter_fp16.invoke()\n",
+ "predictions = interpreter_fp16.get_tensor(\n",
+ " interpreter_fp16.get_output_details()[0][\"index\"])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "CIH7G_MwbY2x",
+ "colab": {}
+ },
+ "source": [
+ "plt.imshow(img[0])\n",
+ "template = \"True:{true}, predicted:{predict}\"\n",
+ "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+ " predict=str(predictions[0])))\n",
+ "plt.grid(False)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "LwN7uIdCd8Gw"
+ },
+ "source": [
+ "### Evaluate the models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "05aeAuWjvjPx",
+ "colab": {}
+ },
+ "source": [
+ "def eval_model(interpreter, mnist_ds):\n",
+ " total_seen = 0\n",
+ " num_correct = 0\n",
+ "\n",
+ " input_index = interpreter.get_input_details()[0][\"index\"]\n",
+ " output_index = interpreter.get_output_details()[0][\"index\"]\n",
+ " for img, label in mnist_ds:\n",
+ " total_seen += 1\n",
+ " interpreter.set_tensor(input_index, img)\n",
+ " interpreter.invoke()\n",
+ " predictions = interpreter.get_tensor(output_index)\n",
+ " if predictions == label.numpy():\n",
+ " num_correct += 1\n",
+ "\n",
+ " if total_seen % 500 == 0:\n",
+ " print(\"Accuracy after %i images: %f\" %\n",
+ " (total_seen, float(num_correct) / float(total_seen)))\n",
+ "\n",
+ " return float(num_correct) / float(total_seen)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "T5mWkSbMcU5z",
+ "colab": {}
+ },
+ "source": [
+ "print(eval_model(interpreter, mnist_ds))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Km3cY9ry8ZlG"
+ },
+ "source": [
+        "We can repeat the evaluation on the float16 quantized model:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "-9cnwiPp6EGm",
+ "colab": {}
+ },
+ "source": [
+        "# NOTE: Colab runs on server CPUs. At the time of writing, TensorFlow Lite\n",
+        "# doesn't have highly optimized server CPU kernels, so this may be slower\n",
+        "# than the float interpreter above. On mobile CPUs, however, a considerable\n",
+        "# speedup can be observed.\n",
+ "print(eval_model(interpreter_fp16, mnist_ds))\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L7lfxkor8pgv"
+ },
+ "source": [
+        "In this example, we have quantized a model to float16 with no difference in accuracy.\n",
+        "\n",
+        "It's also possible to evaluate the float16 quantized model on the GPU. To perform all arithmetic with the reduced precision values, be sure to create the `TfLiteGpuDelegateOptions` struct in your app and set `precision_loss_allowed` to `1`, like this:\n",
+ "\n",
+ "```\n",
+ "//Prepare GPU delegate.\n",
+ "const TfLiteGpuDelegateOptions options = {\n",
+ " .metadata = NULL,\n",
+ " .compile_options = {\n",
+ " .precision_loss_allowed = 1, // FP16\n",
+ " .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST,\n",
+ " .dynamic_batch_enabled = 0, // Not fully functional yet\n",
+ " },\n",
+ "};\n",
+ "```\n",
+ "\n",
+        "Detailed documentation on the TFLite GPU delegate and how to use it in your application can be found [here](https://www.tensorflow.org/lite/performance/gpu_advanced)."
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/tensorflow/lite/tutorials/post_training_integer_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
similarity index 100%
rename from tensorflow/lite/tutorials/post_training_integer_quant.ipynb
rename to tensorflow/lite/g3doc/performance/post_training_integer_quant.ipynb
diff --git a/tensorflow/lite/tutorials/post_training_quant.ipynb b/tensorflow/lite/g3doc/performance/post_training_quant.ipynb
similarity index 100%
rename from tensorflow/lite/tutorials/post_training_quant.ipynb
rename to tensorflow/lite/g3doc/performance/post_training_quant.ipynb
diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md
index 69ebf7ee4a0..30f8c0992e0 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quantization.md
+++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md
@@ -8,6 +8,20 @@ conversion.
### Optimization options
+There are several post-training quantization options to choose from. Here is a
+summary table of the choices and the benefits they provide:
+
+| Technique              | Benefits                                | Hardware            |
+| ---------------------- | --------------------------------------- | ------------------- |
+| Post-training "hybrid" | 4x smaller, 2-3x speedup, accuracy      | CPU                 |
+| Post-training integer  | 4x smaller, more speedup                | CPU, Edge TPU, etc. |
+| Post-training float16  | 2x smaller, potential GPU acceleration  | CPU/GPU             |
+
+This decision tree can help determine which post-training quantization method is
+best for your use case:
+

### Quantizing weights
@@ -78,6 +92,35 @@ Note: `target_spec.supported_ops` was previously `target_ops` in the Python API.
This makes the converter throw an error if it encounters an operation it cannot
currently quantize.
+### Float16 quantization of weights
+
+We can reduce the size of a floating point model by quantizing the weights to
+float16, the IEEE standard for 16-bit floating point numbers. The advantages of
+this quantization are as follows:
+
+- Model size is reduced by up to half (since all weights become half their
+  original size).
+- Accuracy loss is minimal.
+- Some delegates (e.g. the GPU delegate) can operate directly on float16 data,
+  which results in faster execution than float32 computations.
+
+This quantization may not be a good choice if you need maximum performance (a
+quantization to fixed-point math would be better in that case). To enable
+float16 quantization of weights, set the optimizations flag to "DEFAULT" as
+above and then list float16 in the `target_spec` supported types:
+
+```
+import tensorflow as tf
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]
+tflite_quant_model = converter.convert()
+```
+
+By default, a float16 quantized model will "dequantize" the weight values to
+float32 when run on the CPU. The GPU delegate will not perform this
+dequantization, since it can operate on float16 data.
+
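+For example, the float16 quantized model produced above runs on the CPU with
+the standard Python interpreter and no extra steps. This is a minimal sketch:
+it assumes `tflite_quant_model` from the snippet above and feeds a dummy
+all-zeros input just to show the call sequence:
+
+```
+import numpy as np
+import tensorflow as tf
+
+# Load the float16 quantized flatbuffer produced above directly from memory.
+interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
+interpreter.allocate_tensors()
+
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+# Inputs and outputs remain float32; the float16 weights are dequantized
+# on the CPU before inference.
+input_data = np.zeros(input_details[0]["shape"], dtype=np.float32)
+interpreter.set_tensor(input_details[0]["index"], input_data)
+interpreter.invoke()
+output = interpreter.get_tensor(output_details[0]["index"])
+```
+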
### Model accuracy
Since weights are quantized post training, there could be an accuracy loss,