Two speed test notebooks: MNIST training loop and Collatz.

PiperOrigin-RevId: 204757439
This commit is contained in:
Dan Moldovan 2018-07-16 09:48:27 -07:00 committed by TensorFlower Gardener
parent 97ae13e08d
commit b70a39b4e6
2 changed files with 469 additions and 95 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,43 @@
{ {
"cells": [ "cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "etTmZVFN8fYO"
},
"source": [
"This notebook runs a basic speed test for a short training loop of a neural network training on the MNIST dataset."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eqOvRhOz8SWs"
},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "nHY0tntRizGb"
},
"outputs": [],
"source": [
"!pip install -U -q tf-nightly"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": 0,
@ -15,18 +53,93 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import gzip\n",
"import os\n", "import os\n",
"import shutil\n",
"import time\n", "import time\n",
"\n", "\n",
"import numpy as np\n", "import numpy as np\n",
"import six\n", "import six\n",
"from six.moves import urllib\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"\n", "\n",
"from tensorflow.contrib import autograph\n", "from tensorflow.contrib import autograph as ag\n",
"from tensorflow.contrib.eager.python import tfe\n", "from tensorflow.contrib.eager.python import tfe\n",
"from tensorflow.python.eager import context\n" "from tensorflow.python.eager import context\n"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PZWxEJFM9A7b"
},
"source": [
"### Testing boilerplate"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "kfZk9EFZ5TeQ"
},
"outputs": [],
"source": [
"# Test-only parameters. Test checks successful completion not correctness. \n",
"burn_ins = 1\n",
"trials = 1\n",
"max_steps = 2\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "k0GKbZBJ9Gt9"
},
"source": [
"### Speed test configuration"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "gWXV8WHn43iZ"
},
"outputs": [],
"source": [
"#@test {\"skip\": true} \n",
"burn_ins = 3\n",
"trials = 10\n",
"max_steps = 500\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kZV_3pGy8033"
},
"source": [
"### Data source setup"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": 0,
@ -42,12 +155,6 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import gzip\n",
"import shutil\n",
"\n",
"from six.moves import urllib\n",
"\n",
"\n",
"def download(directory, filename):\n", "def download(directory, filename):\n",
" filepath = os.path.join(directory, filename)\n", " filepath = os.path.join(directory, filename)\n",
" if tf.gfile.Exists(filepath):\n", " if tf.gfile.Exists(filepath):\n",
@ -107,6 +214,16 @@
" return ds\n" " return ds\n"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qzkZyZcS9THu"
},
"source": [
"### Keras model definition"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 0, "execution_count": 0,
@ -131,48 +248,6 @@
" return model\n" " return model\n"
] ]
}, },
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "kfZk9EFZ5TeQ"
},
"outputs": [],
"source": [
"# Test-only parameters. Test checks successful completion not correctness. \n",
"burn_ins = 1\n",
"trials = 1\n",
"max_steps = 2\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "gWXV8WHn43iZ"
},
"outputs": [],
"source": [
"#@test {\"skip\": true} \n",
"burn_ins = 3\n",
"trials = 10\n",
"max_steps = 500\n"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
@ -180,7 +255,7 @@
"id": "DXt4GoTxtvn2" "id": "DXt4GoTxtvn2"
}, },
"source": [ "source": [
"# Autograph" "# AutoGraph"
] ]
}, },
{ {
@ -274,10 +349,10 @@
" test_losses = []\n", " test_losses = []\n",
" train_accuracies = []\n", " train_accuracies = []\n",
" test_accuracies = []\n", " test_accuracies = []\n",
" autograph.set_element_type(train_losses, tf.float32)\n", " ag.set_element_type(train_losses, tf.float32)\n",
" autograph.set_element_type(test_losses, tf.float32)\n", " ag.set_element_type(test_losses, tf.float32)\n",
" autograph.set_element_type(train_accuracies, tf.float32)\n", " ag.set_element_type(train_accuracies, tf.float32)\n",
" autograph.set_element_type(test_accuracies, tf.float32)\n", " ag.set_element_type(test_accuracies, tf.float32)\n",
"\n", "\n",
" i = tf.constant(0)\n", " i = tf.constant(0)\n",
" while i \u003c hp.max_steps:\n", " while i \u003c hp.max_steps:\n",
@ -292,26 +367,26 @@
" test_accuracies.append(step_test_accuracy)\n", " test_accuracies.append(step_test_accuracy)\n",
"\n", "\n",
" i += 1\n", " i += 1\n",
" return (autograph.stack(train_losses), autograph.stack(test_losses),\n", " return (ag.stack(train_losses), ag.stack(test_losses),\n",
" autograph.stack(train_accuracies), autograph.stack(test_accuracies))\n" " ag.stack(train_accuracies), ag.stack(test_accuracies))\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 0,
"metadata": { "metadata": {
"colab": { "colab": {
"autoexec": { "autoexec": {
"startup": false, "startup": false,
"wait_interval": 0 "wait_interval": 0
}, },
"height": 220 "height": 215
}, },
"colab_type": "code", "colab_type": "code",
"executionInfo": { "executionInfo": {
"elapsed": 12896, "elapsed": 12156,
"status": "ok", "status": "ok",
"timestamp": 1531534784996, "timestamp": 1531752050611,
"user": { "user": {
"displayName": "", "displayName": "",
"photoUrl": "", "photoUrl": "",
@ -320,24 +395,24 @@
"user_tz": 240 "user_tz": 240
}, },
"id": "K1m8TwOKjdNd", "id": "K1m8TwOKjdNd",
"outputId": "2ee3ff78-9aae-4fac-a1fd-32bf3b2f18f4" "outputId": "bd5746f2-bf91-44aa-9eff-38eb11ced33f"
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"('Duration:', 0.7540969848632812)\n", "('Duration:', 0.6226680278778076)\n",
"('Duration:', 0.7829370498657227)\n", "('Duration:', 0.6082069873809814)\n",
"('Duration:', 0.7111489772796631)\n", "('Duration:', 0.6223258972167969)\n",
"('Duration:', 0.6126768589019775)\n", "('Duration:', 0.6176440715789795)\n",
"('Duration:', 0.6143529415130615)\n", "('Duration:', 0.6309840679168701)\n",
"('Duration:', 0.6174650192260742)\n", "('Duration:', 0.6180410385131836)\n",
"('Duration:', 0.6425611972808838)\n", "('Duration:', 0.6219630241394043)\n",
"('Duration:', 0.6188449859619141)\n", "('Duration:', 0.6183009147644043)\n",
"('Duration:', 0.6388339996337891)\n", "('Duration:', 0.6176400184631348)\n",
"('Duration:', 0.6235959529876709)\n", "('Duration:', 0.6476900577545166)\n",
"('Mean duration:', 0.66165139675140383, '+/-', 0.060382254849383483)\n" "('Mean duration:', 0.62254641056060789, '+/-', 0.0099792188690656976)\n"
] ]
} }
], ],
@ -350,7 +425,7 @@
" )\n", " )\n",
" train_ds = setup_mnist_data(True, hp, 500)\n", " train_ds = setup_mnist_data(True, hp, 500)\n",
" test_ds = setup_mnist_data(False, hp, 100)\n", " test_ds = setup_mnist_data(False, hp, 100)\n",
" tf_train = autograph.to_graph(train)\n", " tf_train = ag.to_graph(train)\n",
" losses = tf_train(train_ds, test_ds, hp)\n", " losses = tf_train(train_ds, test_ds, hp)\n",
"\n", "\n",
" with tf.Session() as sess:\n", " with tf.Session() as sess:\n",
@ -458,20 +533,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 0,
"metadata": { "metadata": {
"colab": { "colab": {
"autoexec": { "autoexec": {
"startup": false, "startup": false,
"wait_interval": 0 "wait_interval": 0
}, },
"height": 220 "height": 215
}, },
"colab_type": "code", "colab_type": "code",
"executionInfo": { "executionInfo": {
"elapsed": 53945, "elapsed": 52499,
"status": "ok", "status": "ok",
"timestamp": 1531534839296, "timestamp": 1531752103279,
"user": { "user": {
"displayName": "", "displayName": "",
"photoUrl": "", "photoUrl": "",
@ -480,24 +555,24 @@
"user_tz": 240 "user_tz": 240
}, },
"id": "plv_yrn_t8Dy", "id": "plv_yrn_t8Dy",
"outputId": "93f2f468-7191-430c-88d2-948b4ce1ea06" "outputId": "55d5ab3d-252d-48ba-8fb4-20ec3c3e6d00"
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"('Duration:', 4.146992206573486)\n", "('Duration:', 3.9973549842834473)\n",
"('Duration:', 4.107615947723389)\n", "('Duration:', 4.018772125244141)\n",
"('Duration:', 4.07602596282959)\n", "('Duration:', 3.9740989208221436)\n",
"('Duration:', 4.113464832305908)\n", "('Duration:', 3.9922947883605957)\n",
"('Duration:', 4.100026845932007)\n", "('Duration:', 3.9795801639556885)\n",
"('Duration:', 4.145462989807129)\n", "('Duration:', 3.966722011566162)\n",
"('Duration:', 4.11216402053833)\n", "('Duration:', 3.986541986465454)\n",
"('Duration:', 4.094243049621582)\n", "('Duration:', 3.992305040359497)\n",
"('Duration:', 4.095034837722778)\n", "('Duration:', 4.012261867523193)\n",
"('Duration:', 4.11162805557251)\n", "('Duration:', 4.004716157913208)\n",
"('Mean duration:', 4.1102658748626713, '+/-', 0.020919605607527668)\n" "('Mean duration:', 3.9924648046493529, '+/-', 0.015681688635624851)\n"
] ]
} }
], ],
@ -535,13 +610,13 @@
], ],
"metadata": { "metadata": {
"colab": { "colab": {
"collapsed_sections": [], "collapsed_sections": [
"eqOvRhOz8SWs",
"PZWxEJFM9A7b",
"kZV_3pGy8033"
],
"default_view": {}, "default_view": {},
"last_runtime": { "name": "Autograph vs. Eager MNIST speed test",
"build_target": "",
"kind": "local"
},
"name": "Autograph vs. Eager MNIST benchmark",
"provenance": [ "provenance": [
{ {
"file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD", "file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD",