STT-tensorflow/tensorflow/python/tpu/async_checkpoint_test.py

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Test async checkpointing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import async_checkpoint
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_estimator
from tensorflow.python.tpu import tpu_optimizer
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import training
from tensorflow_estimator.python.estimator import estimator as estimator_lib
from tensorflow_estimator.python.estimator import model_fn as model_fn_lib

FLAGS = flags.FLAGS

flags.DEFINE_string('tpu', '', 'TPU to use in this test.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string(
    'model_dir',
    os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
    'GCS path to store model and checkpoints.')
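

# The input pipeline is synthetic: random 1000-dimensional float features and
# integer labels in [0, 10), sized by the per-shard 'batch_size' that
# TPUEstimator passes in via params.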
def input_fn(params):
  """Return a dataset of source and target sequences for training."""
  return (constant_op.constant(
      np.random.randn(params['batch_size'], 1000), dtype=dtypes.float32),
          constant_op.constant(
              np.random.randint(0, 10, params['batch_size']),
              dtype=dtypes.int32))
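

# The model is a single dense layer trained with softmax cross-entropy; it
# exists only to give the checkpoint saver some variables to write.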
def model_fn(features, labels, mode, params):
  del params  # unused
  with variable_scope.variable_scope('m', reuse=variable_scope.AUTO_REUSE):
    w = variable_scope.get_variable('W', shape=[1000, 10])

  logits = math_ops.matmul(features, w)
  loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

  if mode == model_fn_lib.ModeKeys.TRAIN:
    optimizer = training.RMSPropOptimizer(learning_rate=0.01)
    optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, training.get_global_step())
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op,
    )
  elif mode == model_fn_lib.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      labels = math_ops.cast(labels, dtypes.int64)
      logging.info('LABELS %s %s', labels, logits)
      return {
          'recall@1': metrics_lib.recall_at_k(labels, logits, 1),
          'recall@5': metrics_lib.recall_at_k(labels, logits, 5),
      }

    loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    eval_metrics = (metric_fn, [labels, logits])
    return tpu_estimator.TPUEstimatorSpec(
        mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metrics=eval_metrics)
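

# For reference, a minimal sketch of how the hook under test is typically
# attached outside of TPUEstimator (names here are illustrative only):
#
#   hook = async_checkpoint.AsyncCheckpointSaverHook(
#       checkpoint_dir='/tmp/model', save_steps=100)
#   with training.MonitoredTrainingSession(hooks=[hook]) as session:
#     while not session.should_stop():
#       session.run(train_op)
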

class AsyncCheckpointingTest(test.TestCase):

  def testAsyncCheckpointHookEnabled(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=11,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))
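
    # save_checkpoints_steps is set far above max_steps so that the async
    # hook, not the estimator's built-in saver, produces the checkpoints this
    # test counts; model_dir also points at a 'runconfig' subdirectory to keep
    # the estimator's own files out of FLAGS.model_dir. keep_checkpoint_max=11
    # appears to be chosen so at most 10 checkpoints survive ("off by one").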
    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    i = 10
    mock_listener = test.mock.create_autospec(
        basic_session_run_hooks.CheckpointSaverListener)
    estimator.train(
        input_fn=input_fn,
        max_steps=i * 10,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval,
                listeners=[mock_listener])
        ])
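
    # With max_steps = i * 10 = 100 and a save every 5 steps, the hook fires
    # roughly 20 times, but older checkpoints are garbage-collected, so no
    # more than 10 should remain on disk at the end.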
    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, 10)
    self.assertEqual(current_step, i * 10)
    mock_listener.before_save.assert_called()
    mock_listener.after_save.assert_called()
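
  # Same scenario as above, but with no listeners attached to the hook; this
  # guards against the hook assuming that listeners are always provided.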
  def testAsyncCheckpointHookWithoutListeners(self):
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

    checkpoint_interval = 5
    keep_checkpoint_max = 10
    config = tpu_config.RunConfig(
        master=resolver.master(),
        model_dir=os.path.join(FLAGS.model_dir, 'runconfig'),
        save_checkpoints_steps=1000,
        keep_checkpoint_max=keep_checkpoint_max + 1,  # off by one
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=checkpoint_interval,))

    estimator = tpu_estimator.TPUEstimator(
        use_tpu=True,
        model_fn=model_fn,
        config=config,
        train_batch_size=32,
        eval_batch_size=32,
        predict_batch_size=1,
        params={},
    )

    max_steps = 100
    estimator.train(
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=[
            async_checkpoint.AsyncCheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=checkpoint_interval)
        ])

    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access

    # TODO(power) -- identify a better way to count the number of checkpoints.
    checkpoints = file_io.get_matching_files(
        FLAGS.model_dir + '/model.ckpt*.meta')
    checkpoint_count = len(checkpoints)
    logging.info('Found %d checkpoints: %s', checkpoint_count, checkpoints)
    self.assertLessEqual(checkpoint_count, keep_checkpoint_max)
    self.assertEqual(current_step, max_steps)
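

# TPUEstimator is a TF1-style, graph-mode API, so TF2 behavior is disabled
# before the tests run.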
if __name__ == '__main__':
  v2_compat.disable_v2_behavior()
  test.main()