From cc8da3cc8fbc53b2408b153a7966e26cb91be86f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Dec 2020 07:38:30 -0800 Subject: [PATCH] Add super resolution TFLite sample notebook PiperOrigin-RevId: 345229956 Change-Id: Ib2b28bfc383bf4767b2791c02c510443660f3bbb --- tensorflow/lite/g3doc/_book.yaml | 2 + .../models/super_resolution/overview.ipynb | 345 ++++++++++++++++++ 2 files changed, 347 insertions(+) create mode 100644 tensorflow/lite/g3doc/models/super_resolution/overview.ipynb diff --git a/tensorflow/lite/g3doc/_book.yaml b/tensorflow/lite/g3doc/_book.yaml index f3fd7db2509..89644d4324d 100644 --- a/tensorflow/lite/g3doc/_book.yaml +++ b/tensorflow/lite/g3doc/_book.yaml @@ -216,6 +216,8 @@ upper_tabs: path: /lite/models/segmentation/overview - title: "Style transfer" path: /lite/models/style_transfer/overview + - title: "Super resolution" + path: /lite/models/super_resolution/overview - heading: "Text" - title: "BERT Question Answer" path: /lite/models/bert_qa/overview diff --git a/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb b/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb new file mode 100644 index 00000000000..bcea0114f47 --- /dev/null +++ b/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb @@ -0,0 +1,345 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "JfOIB1KdkbYW" + }, + "source": [ + "##### Copyright 2020 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Ojb0aXCmBgo7" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "M9Y4JZ0ZGoE4" + }, + "source": [ + "# Super resolution with TensorFlow Lite" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + " \n", + " See TF Hub model\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-uF3N4BbaMvA" + }, + "source": [ + "## Overview" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "isbXET4vVHfu" + }, + "source": [ + "The task of recovering a high resolution (HR) image from its low resolution counterpart is commonly referred to as Single Image Super Resolution (SISR). \n", + "\n", + "The model used here is ESRGAN\n", + "([ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219)). And we are going to use TensorFlow Lite to run inference on the pretrained model.\n", + "\n", + "The TFLite model is converted from this\n", + "[implementation](https://tfhub.dev/captain-pool/esrgan-tf2/1) hosted on TF Hub. Note that the model we converted upsamples a 50x50 low resolution image to a 200x200 high resolution image (scale factor=4). If you want a different input size or scale factor, you need to re-convert or re-train the original model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2dQlTqiffuoU" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qKyMtsGqu3zH" + }, + "source": [ + "Let's install required libraries first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7YTT1Rxsw3A9" + }, + "outputs": [], + "source": [ + "!pip install matplotlib tensorflow tensorflow-hub" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Clz5Kl97FswD" + }, + "source": [ + "Import dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2xh1kvGEBjuP" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_hub as hub\n", + "import matplotlib.pyplot as plt\n", + "print(tf.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i5miVfL4kxTA" + }, + "source": [ + "Download and convert the ESRGAN model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X5PvXIXRwvHj" + }, + "outputs": [], + "source": [ + "model = hub.load(\"https://tfhub.dev/captain-pool/esrgan-tf2/1\")\n", + "concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n", + "concrete_func.inputs[0].set_shape([1, 50, 50, 3])\n", + "converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "tflite_model = converter.convert()\n", + "\n", + "# Save the TF Lite model.\n", + "with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:\n", + " f.write(tflite_model)\n", + "\n", + "esrgan_model_path = './ESRGAN.tflite'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jH5-xPkyUEqt" + }, + "source": [ + "Download a test image (insect head)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "suWiStTWgK6e" + }, + "outputs": [], + "source": [ + "test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rgQ4qRuFNpyW" + }, + "source": [ + "## Generate a super resolution image using TensorFlow Lite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "J9FV4btf02-2" + }, + "outputs": [], + "source": [ + "lr = tf.io.read_file(test_img_path)\n", + "lr = tf.image.decode_jpeg(lr)\n", + "lr = tf.expand_dims(lr, axis=0)\n", + "lr = tf.cast(lr, tf.float32)\n", + "\n", + "# Load TFLite model and allocate tensors.\n", + "interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)\n", + "interpreter.allocate_tensors()\n", + "\n", + "# Get input and output tensors.\n", + "input_details = interpreter.get_input_details()\n", + "output_details = interpreter.get_output_details()\n", + "\n", + "# Run the model\n", + "interpreter.set_tensor(input_details[0]['index'], lr)\n", + "interpreter.invoke()\n", + "\n", + "# Extract the output and postprocess it\n", + "output_data = interpreter.get_tensor(output_details[0]['index'])\n", + "sr = tf.squeeze(output_data, axis=0)\n", + "sr = tf.clip_by_value(sr, 0, 255)\n", + "sr = tf.round(sr)\n", + "sr = tf.cast(sr, tf.uint8)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EwddQrDUNQGO" + }, + "source": [ + "## Visualize the result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aasKuozt1gNd" + }, + "outputs": [], + "source": [ + "lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)\n", + "plt.figure(figsize = (1, 1))\n", + "plt.title('LR')\n", + "plt.imshow(lr.numpy());\n", + "\n", + "plt.figure(figsize=(10, 4))\n", + "plt.subplot(1, 2, 1) \n", + "plt.title(f'ESRGAN (x4)')\n", + "plt.imshow(sr.numpy());\n", + "\n", + "bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\n", + "bicubic = tf.cast(bicubic, tf.uint8)\n", + "plt.subplot(1, 2, 2) \n", + "plt.title('Bicubic')\n", + "plt.imshow(bicubic.numpy());" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0kb-fkogObjq" + }, + "source": [ + "## Performance Benchmarks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tNzdgpqTy5P3" + }, + "source": [ + "Performance benchmark numbers are generated with the tool\n", + "[described here](https://www.tensorflow.org/lite/performance/benchmarks).\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
Model NameModel Size Device CPUGPU
\n", + " super resolution (ESRGAN)\n", + " \n", + " 4.8 Mb\n", + " Pixel 3586.8ms*128.6ms
Pixel 4385.1ms*130.3ms
\n", + "\n", + "**4 threads used*" + ] + } + ], + "metadata": { + "colab": { + "name": "super_resolution.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}