Sync Premade and Custom estimator docs with example code.
PiperOrigin-RevId: 179404175
This commit is contained in:
parent
04df827cfd
commit
119f5d477b
@ -4,13 +4,31 @@ This document introduces custom Estimators. In particular, this document
|
|||||||
demonstrates how to create a custom @{tf.estimator.Estimator$Estimator} that
|
demonstrates how to create a custom @{tf.estimator.Estimator$Estimator} that
|
||||||
mimics the behavior of the pre-made Estimator
|
mimics the behavior of the pre-made Estimator
|
||||||
@{tf.estimator.DNNClassifier$`DNNClassifier`} in solving the Iris problem. See
|
@{tf.estimator.DNNClassifier$`DNNClassifier`} in solving the Iris problem. See
|
||||||
the @{$get_started/estimator$Pre-Made Estimators chapter} for details.
|
the @{$get_started/premade_estimators$Pre-Made Estimators chapter} for details
|
||||||
|
on the Iris problem.
|
||||||
|
|
||||||
|
To download and access the example code invoke the following two commands:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/tensorflow/models/
|
||||||
|
cd models/samples/core/get_started
|
||||||
|
```
|
||||||
|
|
||||||
|
In this document we wil be looking at
|
||||||
|
[`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py).
|
||||||
|
You can run it with the following command:
|
||||||
|
|
||||||
|
```bsh
|
||||||
|
python custom_estimator.py
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are feeling impatient, feel free to compare and contrast
|
||||||
|
[`custom_estimatr.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py)
|
||||||
|
with
|
||||||
|
[`premade_estimatr.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py).
|
||||||
|
(which is in the same directory).
|
||||||
|
|
||||||
If you are feeling impatient, feel free to compare and contrast the following
|
|
||||||
full programs:
|
|
||||||
|
|
||||||
* Iris implemented with the [pre-made DNNClassifier Estimator](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py).
|
|
||||||
* Iris implemented with a [custom Estimator](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py).
|
|
||||||
|
|
||||||
## Pre-made vs. custom
|
## Pre-made vs. custom
|
||||||
|
|
||||||
@ -64,14 +82,16 @@ and a logits output layer.
|
|||||||
|
|
||||||
## Write an Input function
|
## Write an Input function
|
||||||
|
|
||||||
In our custom Estimator implementation, we'll reuse the input function we used
|
Our custom Estimator implementation uses the same input function as our
|
||||||
in the pre-made Estimator implementation. Namely:
|
@{$get_started/premade_estimators$pre-made Estimator implementation}, from
|
||||||
|
[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py).
|
||||||
|
Namely:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def train_input_fn(features, labels, batch_size):
|
def train_input_fn(features, labels, batch_size):
|
||||||
"""An input function for training"""
|
"""An input function for training"""
|
||||||
# Convert the inputs to a Dataset.
|
# Convert the inputs to a Dataset.
|
||||||
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
|
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
|
||||||
|
|
||||||
# Shuffle, repeat, and batch the examples.
|
# Shuffle, repeat, and batch the examples.
|
||||||
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
|
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
|
||||||
@ -85,8 +105,8 @@ This input function builds an input pipeline that yields batches of
|
|||||||
|
|
||||||
## Create feature columns
|
## Create feature columns
|
||||||
|
|
||||||
<!-- TODO(markdaoust): link to feature_columns when it exists-->
|
As detailed in the @{$get_started/estimator$Premade Estimators} and
|
||||||
As detailed in @{$get_started/estimator$Premade Estimators}, you must define
|
@{$get_started/feature_columns$Feature Columns} chapters, you must define
|
||||||
your model's feature columns to specify how the model should use each feature.
|
your model's feature columns to specify how the model should use each feature.
|
||||||
Whether working with pre-made Estimators or custom Estimators, you define
|
Whether working with pre-made Estimators or custom Estimators, you define
|
||||||
feature columns in the same fashion.
|
feature columns in the same fashion.
|
||||||
@ -119,11 +139,14 @@ the input function; that is, `features` and `labels` are the handles to the
|
|||||||
data your model will use. The `mode` argument indicates whether the caller is
|
data your model will use. The `mode` argument indicates whether the caller is
|
||||||
requesting training, predicting, or evaluation.
|
requesting training, predicting, or evaluation.
|
||||||
|
|
||||||
The caller may pass `params` to an Estimator's constructor. The `params` passed
|
The caller may pass `params` to an Estimator's constructor. Any `params` passed
|
||||||
to the constructor become the `params` passed to `model_fn`.
|
to the constructor are in turn passed on to the `model_fn`. In
|
||||||
|
[`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py)
|
||||||
|
the following lines create the estimator and set the params to configure the
|
||||||
|
model. This configuration step is similar to how we configured the @{tf.estimator.DNNClassifier} in
|
||||||
|
@{$get_started/premade_estimators}.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Build 2 hidden layer DNN with 10, 10 units respectively.
|
|
||||||
classifier = tf.estimator.Estimator(
|
classifier = tf.estimator.Estimator(
|
||||||
model_fn=my_model,
|
model_fn=my_model,
|
||||||
params={
|
params={
|
||||||
@ -163,7 +186,7 @@ feature columns into input for your model. For example:
|
|||||||
```
|
```
|
||||||
|
|
||||||
The preceding line applies the transformations defined by your feature columns,
|
The preceding line applies the transformations defined by your feature columns,
|
||||||
creating the input layer of our model.
|
creating the model's input layer.
|
||||||
|
|
||||||
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
<img style="height:260px"
|
<img style="height:260px"
|
||||||
@ -186,6 +209,7 @@ is connected to every node in the preceding layer. Here's the relevant code:
|
|||||||
for units in params['hidden_units']:
|
for units in params['hidden_units']:
|
||||||
net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
|
net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
|
||||||
```
|
```
|
||||||
|
|
||||||
* The `units` parameter defines the number of output neurons in a given layer.
|
* The `units` parameter defines the number of output neurons in a given layer.
|
||||||
* The `activation` parameter defines the [activation function](https://developers.google.com/machine-learning/glossary/#a) —
|
* The `activation` parameter defines the [activation function](https://developers.google.com/machine-learning/glossary/#a) —
|
||||||
[Relu](https://developers.google.com/machine-learning/glossary/#ReLU) in this
|
[Relu](https://developers.google.com/machine-learning/glossary/#ReLU) in this
|
||||||
@ -193,12 +217,11 @@ is connected to every node in the preceding layer. Here's the relevant code:
|
|||||||
|
|
||||||
The variable `net` here signifies the current top layer of the network. During
|
The variable `net` here signifies the current top layer of the network. During
|
||||||
the first iteration, `net` signifies the input layer. On each loop iteration
|
the first iteration, `net` signifies the input layer. On each loop iteration
|
||||||
`tf.layers.dense` creates a new layer, which takes the previous layer as its
|
`tf.layers.dense` creates a new layer, which takes the previous layer's output
|
||||||
input. So, the loop uses `net` to pass the previously created layer as input
|
as its input, using the variable `net`.
|
||||||
to the layer being created.
|
|
||||||
|
|
||||||
After creating two hidden layers, our network looks as follows. For
|
After creating two hidden layers, our network looks as follows. For
|
||||||
simplicity, the figure only shows four hidden units in each layer.
|
simplicity, the figure does not show all the units in each layer.
|
||||||
|
|
||||||
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
<img style="height:260px"
|
<img style="height:260px"
|
||||||
@ -235,8 +258,8 @@ The final hidden layer feeds into the output layer.
|
|||||||
|
|
||||||
When defining an output layer, the `units` parameter specifies the number of
|
When defining an output layer, the `units` parameter specifies the number of
|
||||||
outputs. So, by setting `units` to `params['n_classes']`, the model produces
|
outputs. So, by setting `units` to `params['n_classes']`, the model produces
|
||||||
one output value per class. Each element of the output vector will contains the
|
one output value per class. Each element of the output vector will contain the
|
||||||
score, or "logit", calculated to the associated class of Iris: Setosa,
|
score, or "logit", calculated for the associated class of Iris: Setosa,
|
||||||
Versicolor, or Virginica, respectively.
|
Versicolor, or Virginica, respectively.
|
||||||
|
|
||||||
Later on, these logits will be transformed into probabilities by the
|
Later on, these logits will be transformed into probabilities by the
|
||||||
@ -255,11 +278,12 @@ function looks like this:
|
|||||||
def my_model_fn(
|
def my_model_fn(
|
||||||
features, # This is batch_features from input_fn
|
features, # This is batch_features from input_fn
|
||||||
labels, # This is batch_labels from input_fn
|
labels, # This is batch_labels from input_fn
|
||||||
mode): # An instance of tf.estimator.ModeKeys, see below
|
mode, # An instance of tf.estimator.ModeKeys, see below
|
||||||
|
params): # Additional configuration
|
||||||
```
|
```
|
||||||
|
|
||||||
Focus on that third argument, mode. As the following table shows, when someone
|
Focus on that third argument, mode. As the following table shows, when someone
|
||||||
calls train, evaluate, or predict, the Estimator framework invokes your model
|
calls `train`, `evaluate`, or `predict`, the Estimator framework invokes your model
|
||||||
function with the mode parameter set as follows:
|
function with the mode parameter set as follows:
|
||||||
|
|
||||||
| Estimator method | Estimator Mode |
|
| Estimator method | Estimator Mode |
|
||||||
@ -390,8 +414,8 @@ argument of `tf.estimator.EstimatorSpec`. Here's the code:
|
|||||||
mode, loss=loss, eval_metric_ops=metrics)
|
mode, loss=loss, eval_metric_ops=metrics)
|
||||||
```
|
```
|
||||||
|
|
||||||
The @{tf.summary.scalar} will make accuracy available to TensorBoard (more on
|
The @{tf.summary.scalar} will make accuracy available to TensorBoard
|
||||||
this later).
|
in both `TRAIN` and `EVAL` modes. (More on this later).
|
||||||
|
|
||||||
### Train
|
### Train
|
||||||
|
|
||||||
@ -407,11 +431,10 @@ optimizers—feel free to experiment with them.
|
|||||||
Here is the code that builds the optimizer:
|
Here is the code that builds the optimizer:
|
||||||
|
|
||||||
``` python
|
``` python
|
||||||
# Instantiate an optimizer.
|
|
||||||
optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
|
optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
|
||||||
```
|
```
|
||||||
|
|
||||||
Next, we train the model using the optimizer's
|
Next, we build the training operation using the optimizer's
|
||||||
@{tf.train.Optimizer.minimize$`minimize`} method on the loss we calculated
|
@{tf.train.Optimizer.minimize$`minimize`} method on the loss we calculated
|
||||||
earlier.
|
earlier.
|
||||||
|
|
||||||
@ -425,8 +448,6 @@ argument of `minimize`.
|
|||||||
Here's the code to train the model:
|
Here's the code to train the model:
|
||||||
|
|
||||||
``` python
|
``` python
|
||||||
# Train the model by establishing an objective, which is to
|
|
||||||
# minimize loss using that optimizer.
|
|
||||||
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
|
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -439,11 +460,7 @@ must have the following fields set:
|
|||||||
Here's our code to call `EstimatorSpec`:
|
Here's our code to call `EstimatorSpec`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Return training information.
|
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
|
||||||
return tf.estimator.EstimatorSpec(
|
|
||||||
mode=tf.estimator.ModeKeys.TRAIN,
|
|
||||||
loss=loss,
|
|
||||||
train_op=train_op)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The model function is now complete.
|
The model function is now complete.
|
||||||
@ -469,13 +486,14 @@ arguments of `DNNClassifier`; that is, the `params` dictionary lets you
|
|||||||
configure your Estimator without modifying the code in the `model_fn`.
|
configure your Estimator without modifying the code in the `model_fn`.
|
||||||
|
|
||||||
The rest of the code to train, evaluate, and generate predictions using our
|
The rest of the code to train, evaluate, and generate predictions using our
|
||||||
Estimator is the same as for the pre-made `DNNClassifier`. For example, the
|
Estimator is the same as in the
|
||||||
following line will train the model:
|
@{$get_started/premade_estimators$Premade Estimators} chapter. For
|
||||||
|
example, the following line will train the model:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Train the Model.
|
# Train the Model.
|
||||||
classifier.train(
|
classifier.train(
|
||||||
input_fn=lambda:train_input_fn(train_x, train_y, args.batch_size),
|
input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
|
||||||
steps=args.train_steps)
|
steps=args.train_steps)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ how to write the Iris classification problem in TensorFlow.
|
|||||||
|
|
||||||
Prior to reading this document, do the following:
|
Prior to reading this document, do the following:
|
||||||
|
|
||||||
* [Install TensorFlow](install/index.md).
|
* @{$install$Install TensorFlow}.
|
||||||
* If you installed TensorFlow with virtualenv or Anaconda, activate your
|
* If you installed TensorFlow with virtualenv or Anaconda, activate your
|
||||||
TensorFlow environment.
|
TensorFlow environment.
|
||||||
* To keep the data import simple, our Iris example uses Pandas. You can
|
* To keep the data import simple, our Iris example uses Pandas. You can
|
||||||
@ -28,7 +28,11 @@ Take the following steps to get the sample code for this program:
|
|||||||
|
|
||||||
`cd models/samples/core/get_started/`
|
`cd models/samples/core/get_started/`
|
||||||
|
|
||||||
The program described in this document is called `premade_estimator.py`.
|
The program described in this document is
|
||||||
|
[`premade_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py).
|
||||||
|
This program uses
|
||||||
|
[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py)
|
||||||
|
To fetch its training data.
|
||||||
|
|
||||||
### Running the program
|
### Running the program
|
||||||
|
|
||||||
@ -38,15 +42,15 @@ You run TensorFlow programs as you would run any Python program. For example:
|
|||||||
python premade_estimator.py
|
python premade_estimator.py
|
||||||
```
|
```
|
||||||
|
|
||||||
The program should output training logs and some predictions against a test
|
The program should output training logs followed by some predictions against
|
||||||
set. For example, the first line in the following output shows that the model
|
the test set. For example, the first line in the following output shows that
|
||||||
thinks there is a 99.6% chance that the first example in the test set is a
|
the model thinks there is a 99.6% chance that the first example in the test
|
||||||
Sentosa. Since the test set `expected "Setosa"`, this appears to be a good
|
set is a Setosa. Since the test set `expected "Setosa"`, this appears to be
|
||||||
prediction.
|
a good prediction.
|
||||||
|
|
||||||
``` None
|
``` None
|
||||||
...
|
...
|
||||||
Prediction is "Sentosa" (99.6%), expected "Setosa"
|
Prediction is "Setosa" (99.6%), expected "Setosa"
|
||||||
|
|
||||||
Prediction is "Versicolor" (99.8%), expected "Versicolor"
|
Prediction is "Versicolor" (99.8%), expected "Versicolor"
|
||||||
|
|
||||||
@ -76,12 +80,12 @@ The TensorFlow Programming Environment
|
|||||||
|
|
||||||
We strongly recommend writing TensorFlow programs with the following APIs:
|
We strongly recommend writing TensorFlow programs with the following APIs:
|
||||||
|
|
||||||
* Estimators, which represent a complete model. The Estimator API provides
|
* @{tf.estimator$Estimators}, which represent a complete model.
|
||||||
methods to train the model, to judge the model's accuracy, and to generate
|
The Estimator API provides methods to train the model, to judge the model's
|
||||||
predictions.
|
accuracy, and to generate predictions.
|
||||||
* Datasets, which build a data input pipeline. The Dataset API has methods to
|
* @{$get_started/datasets_quickstart$Datasets}, which build a data input
|
||||||
load and manipulate data, and feed it into your model. The Datasets API meshes
|
pipeline. The Dataset API has methods to load and manipulate data, and feed
|
||||||
well with the Estimators API.
|
it into your model. The Datasets API meshes well with the Estimators API.
|
||||||
|
|
||||||
## Classifying irises: an overview
|
## Classifying irises: an overview
|
||||||
|
|
||||||
@ -130,7 +134,7 @@ The following table shows three examples in the data set:
|
|||||||
|
|
||||||
|sepal length | sepal width | petal length | petal width| species (label) |
|
|sepal length | sepal width | petal length | petal width| species (label) |
|
||||||
|------------:|------------:|-------------:|-----------:|:---------------:|
|
|------------:|------------:|-------------:|-----------:|:---------------:|
|
||||||
| 5.1 | 3.3 | 1.7 | 0.5 | 0 (Sentosa) |
|
| 5.1 | 3.3 | 1.7 | 0.5 | 0 (Setosa) |
|
||||||
| 5.0 | 2.3 | 3.3 | 1.0 | 1 (versicolor)|
|
| 5.0 | 2.3 | 3.3 | 1.0 | 1 (versicolor)|
|
||||||
| 6.4 | 2.8 | 5.6 | 2.2 | 2 (virginica) |
|
| 6.4 | 2.8 | 5.6 | 2.2 | 2 (virginica) |
|
||||||
|
|
||||||
@ -145,11 +149,10 @@ topology:
|
|||||||
The following figure illustrates the features, hidden layers, and predictions
|
The following figure illustrates the features, hidden layers, and predictions
|
||||||
(not all of the nodes in the hidden layers are shown):
|
(not all of the nodes in the hidden layers are shown):
|
||||||
|
|
||||||
|
|
||||||
<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
<img style="width:100%"
|
<img style="width:100%"
|
||||||
alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs"
|
alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs"
|
||||||
src="../images/iris_model.png">
|
src="../images/custom_estimators/full_network.png">
|
||||||
</div>
|
</div>
|
||||||
<div style="text-align: center">
|
<div style="text-align: center">
|
||||||
The Model.
|
The Model.
|
||||||
@ -252,9 +255,11 @@ The Dataset API can handle a lot of common cases for you. For example,
|
|||||||
using the Dataset API, you can easily read in records from a large collection
|
using the Dataset API, you can easily read in records from a large collection
|
||||||
of files in parallel and join them into a single stream.
|
of files in parallel and join them into a single stream.
|
||||||
|
|
||||||
To keep things simple in this example we are going to load the data with pandas, and build our input pipeline from this in-memory data.
|
To keep things simple in this example we are going to load the data with pandas,
|
||||||
|
and build our input pipeline from this in-memory data.
|
||||||
|
|
||||||
Here is the input function used for training in this program:
|
Here is the input function used for training in this program, which is available
|
||||||
|
in [`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py):
|
||||||
|
|
||||||
``` python
|
``` python
|
||||||
def train_input_fn(features, labels, batch_size):
|
def train_input_fn(features, labels, batch_size):
|
||||||
@ -272,14 +277,14 @@ def train_input_fn(features, labels, batch_size):
|
|||||||
## Define the Feature Columns
|
## Define the Feature Columns
|
||||||
|
|
||||||
A [**Feature Column**](https://developers.google.com/machine-learning/glossary/#feature_columns)
|
A [**Feature Column**](https://developers.google.com/machine-learning/glossary/#feature_columns)
|
||||||
is an object describing how the model should use raw input features from the
|
is an object describing how the model should use raw input data from the
|
||||||
features dictionary. When you build an Estimator model, you pass it a list of
|
features dictionary. When you build an Estimator model, you pass it a list of
|
||||||
feature columns that describes each of the features you want the model to use.
|
feature columns that describes each of the features you want the model to use.
|
||||||
|
The @{tf.feature_column} module provides many options for representing data
|
||||||
These objects are created by functions in the @{tf.feature_column} module. `tf.feature_column` methods provide many different ways to represent data.
|
to the model.
|
||||||
|
|
||||||
For Iris, the 4 raw features are numeric values, so we'll build a list of
|
For Iris, the 4 raw features are numeric values, so we'll build a list of
|
||||||
feature columns, to tell the Estimator model to represent each of the four
|
feature columns to tell the Estimator model to represent each of the four
|
||||||
features as 32-bit floating-point values. Therefore, the code to create the
|
features as 32-bit floating-point values. Therefore, the code to create the
|
||||||
Feature Column is simply:
|
Feature Column is simply:
|
||||||
|
|
||||||
@ -291,7 +296,8 @@ for key in train_x.keys():
|
|||||||
```
|
```
|
||||||
|
|
||||||
Feature Columns can be far more sophisticated than those we're showing here.
|
Feature Columns can be far more sophisticated than those we're showing here.
|
||||||
<!--TODO(markdaoust) add link to feature_columns doc when it exists.-->
|
We detail feature columns @{$get_started/feature_columns$later on} in
|
||||||
|
getting started.
|
||||||
|
|
||||||
Now that we have the description of how we want the model to represent the raw
|
Now that we have the description of how we want the model to represent the raw
|
||||||
features, we can build the estimator.
|
features, we can build the estimator.
|
||||||
@ -305,8 +311,7 @@ provides several pre-made classifier Estimators, including:
|
|||||||
* @{tf.estimator.DNNClassifier}—for deep models that perform multi-class
|
* @{tf.estimator.DNNClassifier}—for deep models that perform multi-class
|
||||||
classification.
|
classification.
|
||||||
* @{tf.estimator.DNNLinearCombinedClassifier}—for wide-n-deep models.
|
* @{tf.estimator.DNNLinearCombinedClassifier}—for wide-n-deep models.
|
||||||
* @{tf.estimator.LinearClassifier}—for linear models that feed results into
|
* @{tf.estimator.LinearClassifier}— for classifiers based on linear models.
|
||||||
binary classifiers.
|
|
||||||
|
|
||||||
For the Iris problem, `tf.estimator.DNNClassifier` seems like the best choice.
|
For the Iris problem, `tf.estimator.DNNClassifier` seems like the best choice.
|
||||||
Here's how we instantiated this Estimator:
|
Here's how we instantiated this Estimator:
|
||||||
@ -336,14 +341,15 @@ Train the model by calling the Estimator's `train` method as follows:
|
|||||||
```python
|
```python
|
||||||
# Train the Model.
|
# Train the Model.
|
||||||
classifier.train(
|
classifier.train(
|
||||||
input_fn=lambda:train_input_fn(train_x, train_y, args.batch_size),
|
input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
|
||||||
steps=args.train_steps)
|
steps=args.train_steps)
|
||||||
```
|
```
|
||||||
|
|
||||||
Here we wrap up our `input_fn` call in a [`lambda`](https://docs.python.org/3/tutorial/controlflow.html)
|
Here we wrap up our `input_fn` call in a
|
||||||
to allow the Estimator to call it, at the correct time, with no arguments.
|
[`lambda`](https://docs.python.org/3/tutorial/controlflow.html)
|
||||||
The `steps` argument tells the method to stop training after a number of
|
to capture the arguments while providing an input function that takes no
|
||||||
training steps.
|
arguments, as expected by the Estimator. The `steps` argument tells the method
|
||||||
|
to stop training after a number of training steps.
|
||||||
|
|
||||||
### Evaluate the trained model
|
### Evaluate the trained model
|
||||||
|
|
||||||
@ -354,14 +360,14 @@ model on the test data:
|
|||||||
```python
|
```python
|
||||||
# Evaluate the model.
|
# Evaluate the model.
|
||||||
eval_result = classifier.evaluate(
|
eval_result = classifier.evaluate(
|
||||||
input_fn=lambda:eval_input_fn(test_x, test_y, args.batch_size))
|
input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
|
||||||
|
|
||||||
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
|
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
|
||||||
```
|
```
|
||||||
|
|
||||||
Note how unlike our call to the `train` method, we did not pass the `steps`
|
Unlike our call to the `train` method, we did not pass the `steps`
|
||||||
argument to evaluate. Our `eval_input_fn` doesn't use the `repeat` method on
|
argument to evaluate. Our `eval_input_fn` only yields a single
|
||||||
the dataset, so evaluation just runs to the end of the data.
|
[epoch](https://developers.google.com/machine-learning/glossary/#epoch) of data.
|
||||||
|
|
||||||
Running this code yields the following output (or something similar):
|
Running this code yields the following output (or something similar):
|
||||||
|
|
||||||
@ -387,7 +393,8 @@ predict_x = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
predictions = classifier.predict(
|
predictions = classifier.predict(
|
||||||
input_fn=lambda:eval_input_fn(predict_x, batch_size=args.batch_size))
|
input_fn=lambda:iris_data.eval_input_fn(predict_x,
|
||||||
|
batch_size=args.batch_size))
|
||||||
```
|
```
|
||||||
|
|
||||||
The `predict` method returns a Python iterable, yielding a dictionary of
|
The `predict` method returns a Python iterable, yielding a dictionary of
|
||||||
@ -401,29 +408,35 @@ for pred_dict, expec in zip(predictions, expected):
|
|||||||
|
|
||||||
class_id = pred_dict['class_ids'][0]
|
class_id = pred_dict['class_ids'][0]
|
||||||
probability = pred_dict['probabilities'][class_id]
|
probability = pred_dict['probabilities'][class_id]
|
||||||
print(template.format(SPECIES[class_id], 100 * probability, expec))
|
|
||||||
|
print(template.format(iris_data.SPECIES[class_id],
|
||||||
|
100 * probability, expec))
|
||||||
```
|
```
|
||||||
|
|
||||||
Running the preceding code yields the following output:
|
Running the preceding code yields the following output:
|
||||||
|
|
||||||
``` None
|
``` None
|
||||||
...
|
...
|
||||||
Prediction is "Sentosa" (99.6%), expected "Setosa"
|
Prediction is "Setosa" (99.6%), expected "Setosa"
|
||||||
|
|
||||||
Prediction is "Versicolor" (99.8%), expected "Versicolor"
|
Prediction is "Versicolor" (99.8%), expected "Versicolor"
|
||||||
|
|
||||||
Prediction is "Virginica" (97.9%), expected "Virginica"
|
Prediction is "Virginica" (97.9%), expected "Virginica"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Next
|
|
||||||
|
|
||||||
Now that you've gotten started writing TensorFlow programs.
|
## Summary
|
||||||
|
|
||||||
* For more on Datasets, see the
|
Pre-made Estimators are an effective way to quickly create standard models.
|
||||||
@{$programmers_guide/datasets$Programmer's guide} and
|
|
||||||
@{tf.data$reference documentation}.
|
Now that you've gotten started writing TensorFlow programs, consider the
|
||||||
* For more on Estimators, see the
|
following material:
|
||||||
@{$programmers_guide/estimators$Programmer's guide} and
|
|
||||||
@{tf.estimator$reference documentation}.
|
* @{$get_started/saving_models$Checkpoints} to learn how to save and restore
|
||||||
<!--TODO(markdaoust) add links to next get_started section when it exists.-->
|
models.
|
||||||
|
* @{$get_started/datasets_quickstart$Datasets} to learn more about importing
|
||||||
|
data into your
|
||||||
|
model.
|
||||||
|
* @{$get_started/custom_estimators$Creating Custom Estimators} to learn how to
|
||||||
|
write your own Estimator, customized for a particular problem.
|
||||||
|
|
||||||
|
@ -15,9 +15,8 @@ This document focuses on checkpoints. For details on SavedModel, see the
|
|||||||
|
|
||||||
## Sample code
|
## Sample code
|
||||||
|
|
||||||
This document relies on the same Iris classification example detailed in
|
This document relies on the same
|
||||||
<!-- TODO (barryr): fill in link when module settles down. -->
|
[https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py](Iris classification example) detailed in @{$premade_estimators$Getting Started with TensorFlow}.
|
||||||
@{$premade_estimators$Getting Started with TensorFlow}.
|
|
||||||
To download and access the example, invoke the following two commands:
|
To download and access the example, invoke the following two commands:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
@ -228,10 +227,12 @@ This separation will keep your checkpoints recoverable.
|
|||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
Checkpoints provide an easy automatic mechanism for storing and restoring
|
Checkpoints provide an easy automatic mechanism for saving and restoring
|
||||||
models created by Estimators. See the @{$saved_model$Saving and Restoring}
|
models created by Estimators.
|
||||||
|
|
||||||
|
See the @{$saved_model$Saving and Restoring}
|
||||||
chapter of the *TensorFlow Programmer's Guide* for details on:
|
chapter of the *TensorFlow Programmer's Guide* for details on:
|
||||||
|
|
||||||
* Saving and restoring models created by low-level TensorFlow APIs.
|
* Saving and restoring models using low-level TensorFlow APIs.
|
||||||
* Saving and restoring models in the SavedModel format, which is a
|
* Exporting and importing models in the SavedModel format, which is a
|
||||||
language-neutral, recoverable, serialization format.
|
language-neutral, recoverable, serialization format.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user