Add super resolution TFLite sample notebook

PiperOrigin-RevId: 345229956
Change-Id: Ib2b28bfc383bf4767b2791c02c510443660f3bbb
This commit is contained in:
A. Unique TensorFlower 2020-12-02 07:38:30 -08:00 committed by TensorFlower Gardener
parent ec0d1fd14d
commit cc8da3cc8f
2 changed files with 347 additions and 0 deletions

View File

@ -216,6 +216,8 @@ upper_tabs:
path: /lite/models/segmentation/overview
- title: "Style transfer"
path: /lite/models/style_transfer/overview
- title: "Super resolution"
path: /lite/models/super_resolution/overview
- heading: "Text"
- title: "BERT Question Answer"
path: /lite/models/bert_qa/overview

View File

@ -0,0 +1,345 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "JfOIB1KdkbYW"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Ojb0aXCmBgo7"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M9Y4JZ0ZGoE4"
},
"source": [
"# Super resolution with TensorFlow Lite"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/models/super_resolution/overview\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://tfhub.dev/captain-pool/esrgan-tf2/1\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-uF3N4BbaMvA"
},
"source": [
"## Overview"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "isbXET4vVHfu"
},
"source": [
"The task of recovering a high resolution (HR) image from its low resolution counterpart is commonly referred to as Single Image Super Resolution (SISR). \n",
"\n",
"The model used here is ESRGAN\n",
"([ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219)). And we are going to use TensorFlow Lite to run inference on the pretrained model.\n",
"\n",
"The TFLite model is converted from this\n",
"[implementation](https://tfhub.dev/captain-pool/esrgan-tf2/1) hosted on TF Hub. Note that the model we converted upsamples a 50x50 low resolution image to a 200x200 high resolution image (scale factor=4). If you want a different input size or scale factor, you need to re-convert or re-train the original model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2dQlTqiffuoU"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qKyMtsGqu3zH"
},
"source": [
"Let's install required libraries first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7YTT1Rxsw3A9"
},
"outputs": [],
"source": [
"!pip install matplotlib tensorflow tensorflow-hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Clz5Kl97FswD"
},
"source": [
"Import dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2xh1kvGEBjuP"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import matplotlib.pyplot as plt\n",
"print(tf.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i5miVfL4kxTA"
},
"source": [
"Download and convert the ESRGAN model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X5PvXIXRwvHj"
},
"outputs": [],
"source": [
"model = hub.load(\"https://tfhub.dev/captain-pool/esrgan-tf2/1\")\n",
"concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
"concrete_func.inputs[0].set_shape([1, 50, 50, 3])\n",
"converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"tflite_model = converter.convert()\n",
"\n",
"# Save the TF Lite model.\n",
"with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:\n",
" f.write(tflite_model)\n",
"\n",
"esrgan_model_path = './ESRGAN.tflite'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jH5-xPkyUEqt"
},
"source": [
"Download a test image (insect head)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "suWiStTWgK6e"
},
"outputs": [],
"source": [
"test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rgQ4qRuFNpyW"
},
"source": [
"## Generate a super resolution image using TensorFlow Lite"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J9FV4btf02-2"
},
"outputs": [],
"source": [
"lr = tf.io.read_file(test_img_path)\n",
"lr = tf.image.decode_jpeg(lr)\n",
"lr = tf.expand_dims(lr, axis=0)\n",
"lr = tf.cast(lr, tf.float32)\n",
"\n",
"# Load TFLite model and allocate tensors.\n",
"interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)\n",
"interpreter.allocate_tensors()\n",
"\n",
"# Get input and output tensors.\n",
"input_details = interpreter.get_input_details()\n",
"output_details = interpreter.get_output_details()\n",
"\n",
"# Run the model\n",
"interpreter.set_tensor(input_details[0]['index'], lr)\n",
"interpreter.invoke()\n",
"\n",
"# Extract the output and postprocess it\n",
"output_data = interpreter.get_tensor(output_details[0]['index'])\n",
"sr = tf.squeeze(output_data, axis=0)\n",
"sr = tf.clip_by_value(sr, 0, 255)\n",
"sr = tf.round(sr)\n",
"sr = tf.cast(sr, tf.uint8)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EwddQrDUNQGO"
},
"source": [
"## Visualize the result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aasKuozt1gNd"
},
"outputs": [],
"source": [
"lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)\n",
"plt.figure(figsize = (1, 1))\n",
"plt.title('LR')\n",
"plt.imshow(lr.numpy());\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.subplot(1, 2, 1) \n",
"plt.title(f'ESRGAN (x4)')\n",
"plt.imshow(sr.numpy());\n",
"\n",
"bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\n",
"bicubic = tf.cast(bicubic, tf.uint8)\n",
"plt.subplot(1, 2, 2) \n",
"plt.title('Bicubic')\n",
"plt.imshow(bicubic.numpy());"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kb-fkogObjq"
},
"source": [
"## Performance Benchmarks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tNzdgpqTy5P3"
},
"source": [
"Performance benchmark numbers are generated with the tool\n",
"[described here](https://www.tensorflow.org/lite/performance/benchmarks).\n",
"\n",
"<table>\n",
" <thead>\n",
" <tr>\n",
" <th>Model Name</th>\n",
" <th>Model Size </th>\n",
" <th>Device </th>\n",
" <th>CPU</th>\n",
" <th>GPU</th>\n",
" </tr>\n",
" </thead>\n",
" <tr>\n",
" <td rowspan = 3>\n",
" super resolution (ESRGAN)\n",
" </td>\n",
" <td rowspan = 3>\n",
" 4.8 Mb\n",
" </td>\n",
" <td>Pixel 3</td>\n",
" <td>586.8ms*</td>\n",
" <td>128.6ms</td>\n",
" </tr>\n",
" <tr>\n",
" <td>Pixel 4</td>\n",
" <td>385.1ms*</td>\n",
" <td>130.3ms</td>\n",
" </tr>\n",
"\n",
"</table>\n",
"\n",
"**4 threads used*"
]
}
],
"metadata": {
"colab": {
"name": "super_resolution.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}