Two speed test notebooks: MNIST training loop and Collatz.
PiperOrigin-RevId: 204757439
This commit is contained in:
parent
97ae13e08d
commit
b70a39b4e6
File diff suppressed because one or more lines are too long
@ -1,5 +1,43 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "etTmZVFN8fYO"
|
||||
},
|
||||
"source": [
|
||||
"This notebook runs a basic speed test for a short training loop of a neural network training on the MNIST dataset."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "eqOvRhOz8SWs"
|
||||
},
|
||||
"source": [
|
||||
"### Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "nHY0tntRizGb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U -q tf-nightly"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
@ -15,18 +53,93 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gzip\n",
|
||||
"import os\n",
|
||||
"import shutil\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import six\n",
|
||||
"from six.moves import urllib\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"from tensorflow.contrib import autograph\n",
|
||||
"from tensorflow.contrib import autograph as ag\n",
|
||||
"from tensorflow.contrib.eager.python import tfe\n",
|
||||
"from tensorflow.python.eager import context\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PZWxEJFM9A7b"
|
||||
},
|
||||
"source": [
|
||||
"### Testing boilerplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "kfZk9EFZ5TeQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test-only parameters. Test checks successful completion not correctness. \n",
|
||||
"burn_ins = 1\n",
|
||||
"trials = 1\n",
|
||||
"max_steps = 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "k0GKbZBJ9Gt9"
|
||||
},
|
||||
"source": [
|
||||
"### Speed test configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "gWXV8WHn43iZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@test {\"skip\": true} \n",
|
||||
"burn_ins = 3\n",
|
||||
"trials = 10\n",
|
||||
"max_steps = 500\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "kZV_3pGy8033"
|
||||
},
|
||||
"source": [
|
||||
"### Data source setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
@ -42,12 +155,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gzip\n",
|
||||
"import shutil\n",
|
||||
"\n",
|
||||
"from six.moves import urllib\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def download(directory, filename):\n",
|
||||
" filepath = os.path.join(directory, filename)\n",
|
||||
" if tf.gfile.Exists(filepath):\n",
|
||||
@ -107,6 +214,16 @@
|
||||
" return ds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "qzkZyZcS9THu"
|
||||
},
|
||||
"source": [
|
||||
"### Keras model definition"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
@ -131,48 +248,6 @@
|
||||
" return model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "kfZk9EFZ5TeQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test-only parameters. Test checks successful completion not correctness. \n",
|
||||
"burn_ins = 1\n",
|
||||
"trials = 1\n",
|
||||
"max_steps = 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "gWXV8WHn43iZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@test {\"skip\": true} \n",
|
||||
"burn_ins = 3\n",
|
||||
"trials = 10\n",
|
||||
"max_steps = 500\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
@ -180,7 +255,7 @@
|
||||
"id": "DXt4GoTxtvn2"
|
||||
},
|
||||
"source": [
|
||||
"# Autograph"
|
||||
"# AutoGraph"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -274,10 +349,10 @@
|
||||
" test_losses = []\n",
|
||||
" train_accuracies = []\n",
|
||||
" test_accuracies = []\n",
|
||||
" autograph.set_element_type(train_losses, tf.float32)\n",
|
||||
" autograph.set_element_type(test_losses, tf.float32)\n",
|
||||
" autograph.set_element_type(train_accuracies, tf.float32)\n",
|
||||
" autograph.set_element_type(test_accuracies, tf.float32)\n",
|
||||
" ag.set_element_type(train_losses, tf.float32)\n",
|
||||
" ag.set_element_type(test_losses, tf.float32)\n",
|
||||
" ag.set_element_type(train_accuracies, tf.float32)\n",
|
||||
" ag.set_element_type(test_accuracies, tf.float32)\n",
|
||||
"\n",
|
||||
" i = tf.constant(0)\n",
|
||||
" while i \u003c hp.max_steps:\n",
|
||||
@ -292,26 +367,26 @@
|
||||
" test_accuracies.append(step_test_accuracy)\n",
|
||||
"\n",
|
||||
" i += 1\n",
|
||||
" return (autograph.stack(train_losses), autograph.stack(test_losses),\n",
|
||||
" autograph.stack(train_accuracies), autograph.stack(test_accuracies))\n"
|
||||
" return (ag.stack(train_losses), ag.stack(test_losses),\n",
|
||||
" ag.stack(train_accuracies), ag.stack(test_accuracies))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
},
|
||||
"height": 220
|
||||
"height": 215
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 12896,
|
||||
"elapsed": 12156,
|
||||
"status": "ok",
|
||||
"timestamp": 1531534784996,
|
||||
"timestamp": 1531752050611,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
@ -320,24 +395,24 @@
|
||||
"user_tz": 240
|
||||
},
|
||||
"id": "K1m8TwOKjdNd",
|
||||
"outputId": "2ee3ff78-9aae-4fac-a1fd-32bf3b2f18f4"
|
||||
"outputId": "bd5746f2-bf91-44aa-9eff-38eb11ced33f"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"('Duration:', 0.7540969848632812)\n",
|
||||
"('Duration:', 0.7829370498657227)\n",
|
||||
"('Duration:', 0.7111489772796631)\n",
|
||||
"('Duration:', 0.6126768589019775)\n",
|
||||
"('Duration:', 0.6143529415130615)\n",
|
||||
"('Duration:', 0.6174650192260742)\n",
|
||||
"('Duration:', 0.6425611972808838)\n",
|
||||
"('Duration:', 0.6188449859619141)\n",
|
||||
"('Duration:', 0.6388339996337891)\n",
|
||||
"('Duration:', 0.6235959529876709)\n",
|
||||
"('Mean duration:', 0.66165139675140383, '+/-', 0.060382254849383483)\n"
|
||||
"('Duration:', 0.6226680278778076)\n",
|
||||
"('Duration:', 0.6082069873809814)\n",
|
||||
"('Duration:', 0.6223258972167969)\n",
|
||||
"('Duration:', 0.6176440715789795)\n",
|
||||
"('Duration:', 0.6309840679168701)\n",
|
||||
"('Duration:', 0.6180410385131836)\n",
|
||||
"('Duration:', 0.6219630241394043)\n",
|
||||
"('Duration:', 0.6183009147644043)\n",
|
||||
"('Duration:', 0.6176400184631348)\n",
|
||||
"('Duration:', 0.6476900577545166)\n",
|
||||
"('Mean duration:', 0.62254641056060789, '+/-', 0.0099792188690656976)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -350,7 +425,7 @@
|
||||
" )\n",
|
||||
" train_ds = setup_mnist_data(True, hp, 500)\n",
|
||||
" test_ds = setup_mnist_data(False, hp, 100)\n",
|
||||
" tf_train = autograph.to_graph(train)\n",
|
||||
" tf_train = ag.to_graph(train)\n",
|
||||
" losses = tf_train(train_ds, test_ds, hp)\n",
|
||||
"\n",
|
||||
" with tf.Session() as sess:\n",
|
||||
@ -458,20 +533,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
},
|
||||
"height": 220
|
||||
"height": 215
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 53945,
|
||||
"elapsed": 52499,
|
||||
"status": "ok",
|
||||
"timestamp": 1531534839296,
|
||||
"timestamp": 1531752103279,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
@ -480,24 +555,24 @@
|
||||
"user_tz": 240
|
||||
},
|
||||
"id": "plv_yrn_t8Dy",
|
||||
"outputId": "93f2f468-7191-430c-88d2-948b4ce1ea06"
|
||||
"outputId": "55d5ab3d-252d-48ba-8fb4-20ec3c3e6d00"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"('Duration:', 4.146992206573486)\n",
|
||||
"('Duration:', 4.107615947723389)\n",
|
||||
"('Duration:', 4.07602596282959)\n",
|
||||
"('Duration:', 4.113464832305908)\n",
|
||||
"('Duration:', 4.100026845932007)\n",
|
||||
"('Duration:', 4.145462989807129)\n",
|
||||
"('Duration:', 4.11216402053833)\n",
|
||||
"('Duration:', 4.094243049621582)\n",
|
||||
"('Duration:', 4.095034837722778)\n",
|
||||
"('Duration:', 4.11162805557251)\n",
|
||||
"('Mean duration:', 4.1102658748626713, '+/-', 0.020919605607527668)\n"
|
||||
"('Duration:', 3.9973549842834473)\n",
|
||||
"('Duration:', 4.018772125244141)\n",
|
||||
"('Duration:', 3.9740989208221436)\n",
|
||||
"('Duration:', 3.9922947883605957)\n",
|
||||
"('Duration:', 3.9795801639556885)\n",
|
||||
"('Duration:', 3.966722011566162)\n",
|
||||
"('Duration:', 3.986541986465454)\n",
|
||||
"('Duration:', 3.992305040359497)\n",
|
||||
"('Duration:', 4.012261867523193)\n",
|
||||
"('Duration:', 4.004716157913208)\n",
|
||||
"('Mean duration:', 3.9924648046493529, '+/-', 0.015681688635624851)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -535,13 +610,13 @@
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"collapsed_sections": [
|
||||
"eqOvRhOz8SWs",
|
||||
"PZWxEJFM9A7b",
|
||||
"kZV_3pGy8033"
|
||||
],
|
||||
"default_view": {},
|
||||
"last_runtime": {
|
||||
"build_target": "",
|
||||
"kind": "local"
|
||||
},
|
||||
"name": "Autograph vs. Eager MNIST benchmark",
|
||||
"name": "Autograph vs. Eager MNIST speed test",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD",
|
Loading…
Reference in New Issue
Block a user