# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.training.evaluation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import training

_USE_GLOBAL_STEP = 0


def logistic_classifier(inputs):
  """A one-unit dense layer with sigmoid activation: logistic regression."""
  return layers.dense(inputs, 1, activation=math_ops.sigmoid)


def local_variable(init_value, name):
  """Returns a non-trainable float32 variable in LOCAL_VARIABLES."""
  return variable_scope.get_variable(
      name,
      dtype=dtypes.float32,
      initializer=init_value,
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES])


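# Note: variables registered only in the LOCAL_VARIABLES collection are not
# written to checkpoints, and the evaluation session initializes them from
# scratch, so counters built with local_variable() start at 0.0 on every run.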
class EvaluateOnceTest(test.TestCase):

  def setUp(self):
    super(EvaluateOnceTest, self).setUp()

    # Create an easy training set: each input is one-hot at index
    # 2 * label + {0, 1}, so the two classes are linearly separable.
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss_op = losses.log_loss(labels=tf_labels, predictions=tf_predictions)

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = optimizer.minimize(loss_op,
                                    training.get_or_create_global_step())

      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir=checkpoint_dir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)]) as session:
        loss = None
        while not session.should_stop():
          _, loss = session.run([train_op, loss_op])

      if num_steps >= 300:
        assert loss < .015

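  # The tests below drive evaluation._evaluate_once, which (as exercised here)
  # restores the given checkpoint into a fresh monitored session, runs
  # eval_ops repeatedly until a hook such as _StopAfterNEvalsHook requests a
  # stop, and finally returns the evaluated values of final_ops.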
  def testEvaluatePerfectModel(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_perfect_model_once')

    # Train a model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Run evaluation on the same (memorized) training data:
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)
    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metrics_module.accuracy(labels, predictions)

    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={'accuracy': (accuracy, update_op)},
        hooks=[
            evaluation._StopAfterNEvalsHook(1),
        ])
    self.assertGreater(final_ops_values['accuracy'], .99)

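  # metrics_module.accuracy follows the standard TF1 metrics contract: it
  # returns a (value, update_op) pair, where running update_op accumulates
  # per-batch counts and the value tensor reads the aggregate afterwards.
  # That is why update_op goes in eval_ops while the value goes in final_ops.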
  def testEvaluateWithFiniteInputs(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_with_finite_inputs')

    # Train a model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Run evaluation. Inputs are fed through an input producer for one epoch.
    all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    all_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    single_input, single_label = training.slice_input_producer(
        [all_inputs, all_labels], num_epochs=1)
    inputs, labels = training.batch([single_input, single_label], batch_size=6,
                                    allow_smaller_final_batch=True)

    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metrics_module.accuracy(labels, predictions)

    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={
            'accuracy': (accuracy, update_op),
            'eval_steps': evaluation._get_or_create_eval_step()
        },
        hooks=[
            evaluation._StopAfterNEvalsHook(None),
        ])
    self.assertGreater(final_ops_values['accuracy'], .99)
    # Evaluation runs for 4 iterations: the first two process full batches of
    # 6 inputs each, the third processes the remaining 4 inputs, and the
    # fourth finds the one-epoch input queue exhausted; the resulting
    # OutOfRangeError stops evaluation.
    self.assertEqual(final_ops_values['eval_steps'], 4)

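  # The remaining tests swap the real metric for a counting local variable in
  # order to probe the loop mechanics: eval_ops run once per evaluation
  # iteration, while final_ops are evaluated a single time at the end.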
  def testEvalOpAndFinalOp(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 5
    final_increment = 9.0

    my_var = local_variable(0.0, name='MyVar')
    eval_ops = state_ops.assign_add(my_var, 1.0)
    final_ops = array_ops.identity(my_var) + final_increment

    final_hooks = [evaluation._StopAfterNEvalsHook(num_evals),]
    initial_hooks = list(final_hooks)
    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=final_hooks)
    # eval_ops ran num_evals times, so my_var holds 5.0 and value is 5.0 + 9.0.
    self.assertEqual(final_ops_values['value'], num_evals + final_increment)
    # _evaluate_once must not mutate the hooks list passed in by the caller.
    self.assertEqual(initial_hooks, final_hooks)

  def testMultiEvalStepIncrements(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 6

    my_var = local_variable(0.0, name='MyVar')
    # The eval_ops also increment the eval step, so each iteration advances
    # the counter by two: once here and once by the evaluation loop itself.
    eval_ops = [state_ops.assign_add(my_var, 1.0),
                state_ops.assign_add(
                    evaluation._get_or_create_eval_step(), 1, use_locking=True)]
    # The hook stops once the eval step reaches num_evals, so only
    # num_evals // 2 iterations actually run.
    expect_eval_update_counts = num_evals // 2

    final_ops = array_ops.identity(my_var)

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=[evaluation._StopAfterNEvalsHook(num_evals),])
    self.assertEqual(final_ops_values['value'], expect_eval_update_counts)

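  # eval_ops is optional: when only final_ops is supplied, no evaluation loop
  # runs, so my_var below keeps its initial 0.0 and the result is just the
  # final_increment constant.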
  def testOnlyFinalOp(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'only_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    final_increment = 9.0

    my_var = local_variable(0.0, name='MyVar')
    final_ops = array_ops.identity(my_var) + final_increment

    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path, final_ops={'value': final_ops})
    self.assertEqual(final_ops_values['value'], final_increment)


if __name__ == '__main__':
  test.main()