Add super resolution TFLite sample notebook
PiperOrigin-RevId: 345229956 Change-Id: Ib2b28bfc383bf4767b2791c02c510443660f3bbb
This commit is contained in:
parent
ec0d1fd14d
commit
cc8da3cc8f
@ -216,6 +216,8 @@ upper_tabs:
|
||||
path: /lite/models/segmentation/overview
|
||||
- title: "Style transfer"
|
||||
path: /lite/models/style_transfer/overview
|
||||
- title: "Super resolution"
|
||||
path: /lite/models/super_resolution/overview
|
||||
- heading: "Text"
|
||||
- title: "BERT Question Answer"
|
||||
path: /lite/models/bert_qa/overview
|
||||
|
345
tensorflow/lite/g3doc/models/super_resolution/overview.ipynb
Normal file
345
tensorflow/lite/g3doc/models/super_resolution/overview.ipynb
Normal file
@ -0,0 +1,345 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "JfOIB1KdkbYW"
|
||||
},
|
||||
"source": [
|
||||
"##### Copyright 2020 The TensorFlow Authors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "Ojb0aXCmBgo7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "M9Y4JZ0ZGoE4"
|
||||
},
|
||||
"source": [
|
||||
"# Super resolution with TensorFlow Lite"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://www.tensorflow.org/lite/models/super_resolution/overview\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/models/super_resolution/overview.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
|
||||
" </td>\n",
|
||||
" <td>\n",
|
||||
" <a href=\"https://tfhub.dev/captain-pool/esrgan-tf2/1\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
|
||||
" </td>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-uF3N4BbaMvA"
|
||||
},
|
||||
"source": [
|
||||
"## Overview"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "isbXET4vVHfu"
|
||||
},
|
||||
"source": [
|
||||
"The task of recovering a high resolution (HR) image from its low resolution counterpart is commonly referred to as Single Image Super Resolution (SISR). \n",
|
||||
"\n",
|
||||
"The model used here is ESRGAN\n",
|
||||
"([ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219)). And we are going to use TensorFlow Lite to run inference on the pretrained model.\n",
|
||||
"\n",
|
||||
"The TFLite model is converted from this\n",
|
||||
"[implementation](https://tfhub.dev/captain-pool/esrgan-tf2/1) hosted on TF Hub. Note that the model we converted upsamples a 50x50 low resolution image to a 200x200 high resolution image (scale factor=4). If you want a different input size or scale factor, you need to re-convert or re-train the original model."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2dQlTqiffuoU"
|
||||
},
|
||||
"source": [
|
||||
"## Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "qKyMtsGqu3zH"
|
||||
},
|
||||
"source": [
|
||||
"Let's install required libraries first."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "7YTT1Rxsw3A9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install matplotlib tensorflow tensorflow-hub"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Clz5Kl97FswD"
|
||||
},
|
||||
"source": [
|
||||
"Import dependencies."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "2xh1kvGEBjuP"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf\n",
|
||||
"import tensorflow_hub as hub\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"print(tf.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "i5miVfL4kxTA"
|
||||
},
|
||||
"source": [
|
||||
"Download and convert the ESRGAN model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "X5PvXIXRwvHj"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = hub.load(\"https://tfhub.dev/captain-pool/esrgan-tf2/1\")\n",
|
||||
"concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
|
||||
"concrete_func.inputs[0].set_shape([1, 50, 50, 3])\n",
|
||||
"converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])\n",
|
||||
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
|
||||
"tflite_model = converter.convert()\n",
|
||||
"\n",
|
||||
"# Save the TF Lite model.\n",
|
||||
"with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:\n",
|
||||
" f.write(tflite_model)\n",
|
||||
"\n",
|
||||
"esrgan_model_path = './ESRGAN.tflite'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "jH5-xPkyUEqt"
|
||||
},
|
||||
"source": [
|
||||
"Download a test image (insect head)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "suWiStTWgK6e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_img_path = tf.keras.utils.get_file('lr.jpg', 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "rgQ4qRuFNpyW"
|
||||
},
|
||||
"source": [
|
||||
"## Generate a super resolution image using TensorFlow Lite"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "J9FV4btf02-2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"lr = tf.io.read_file(test_img_path)\n",
|
||||
"lr = tf.image.decode_jpeg(lr)\n",
|
||||
"lr = tf.expand_dims(lr, axis=0)\n",
|
||||
"lr = tf.cast(lr, tf.float32)\n",
|
||||
"\n",
|
||||
"# Load TFLite model and allocate tensors.\n",
|
||||
"interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)\n",
|
||||
"interpreter.allocate_tensors()\n",
|
||||
"\n",
|
||||
"# Get input and output tensors.\n",
|
||||
"input_details = interpreter.get_input_details()\n",
|
||||
"output_details = interpreter.get_output_details()\n",
|
||||
"\n",
|
||||
"# Run the model\n",
|
||||
"interpreter.set_tensor(input_details[0]['index'], lr)\n",
|
||||
"interpreter.invoke()\n",
|
||||
"\n",
|
||||
"# Extract the output and postprocess it\n",
|
||||
"output_data = interpreter.get_tensor(output_details[0]['index'])\n",
|
||||
"sr = tf.squeeze(output_data, axis=0)\n",
|
||||
"sr = tf.clip_by_value(sr, 0, 255)\n",
|
||||
"sr = tf.round(sr)\n",
|
||||
"sr = tf.cast(sr, tf.uint8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "EwddQrDUNQGO"
|
||||
},
|
||||
"source": [
|
||||
"## Visualize the result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "aasKuozt1gNd"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)\n",
|
||||
"plt.figure(figsize = (1, 1))\n",
|
||||
"plt.title('LR')\n",
|
||||
"plt.imshow(lr.numpy());\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(10, 4))\n",
|
||||
"plt.subplot(1, 2, 1) \n",
|
||||
"plt.title(f'ESRGAN (x4)')\n",
|
||||
"plt.imshow(sr.numpy());\n",
|
||||
"\n",
|
||||
"bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)\n",
|
||||
"bicubic = tf.cast(bicubic, tf.uint8)\n",
|
||||
"plt.subplot(1, 2, 2) \n",
|
||||
"plt.title('Bicubic')\n",
|
||||
"plt.imshow(bicubic.numpy());"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "0kb-fkogObjq"
|
||||
},
|
||||
"source": [
|
||||
"## Performance Benchmarks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "tNzdgpqTy5P3"
|
||||
},
|
||||
"source": [
|
||||
"Performance benchmark numbers are generated with the tool\n",
|
||||
"[described here](https://www.tensorflow.org/lite/performance/benchmarks).\n",
|
||||
"\n",
|
||||
"<table>\n",
|
||||
" <thead>\n",
|
||||
" <tr>\n",
|
||||
" <th>Model Name</th>\n",
|
||||
" <th>Model Size </th>\n",
|
||||
" <th>Device </th>\n",
|
||||
" <th>CPU</th>\n",
|
||||
" <th>GPU</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tr>\n",
|
||||
" <td rowspan = 3>\n",
|
||||
" super resolution (ESRGAN)\n",
|
||||
" </td>\n",
|
||||
" <td rowspan = 3>\n",
|
||||
" 4.8 Mb\n",
|
||||
" </td>\n",
|
||||
" <td>Pixel 3</td>\n",
|
||||
" <td>586.8ms*</td>\n",
|
||||
" <td>128.6ms</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>Pixel 4</td>\n",
|
||||
" <td>385.1ms*</td>\n",
|
||||
" <td>130.3ms</td>\n",
|
||||
" </tr>\n",
|
||||
"\n",
|
||||
"</table>\n",
|
||||
"\n",
|
||||
"**4 threads used*"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "super_resolution.ipynb",
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
Loading…
Reference in New Issue
Block a user