Add profile logging to TPUEstimator.

PiperOrigin-RevId: 240469608
A. Unique TensorFlower 2019-03-26 18:12:07 -07:00 committed by TensorFlower Gardener
parent 3d2488f052
commit a26413ef0a
3 changed files with 94 additions and 1 deletion

@@ -125,6 +125,7 @@ py_library(
"__init__.py",
"bfloat16.py",
"device_assignment.py",
"profile_logger.py",
"session_support.py",
"tensor_tracer.py",
"topology.py",

@@ -0,0 +1,69 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""A logger for profiling events."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary.writer import writer
class ProfileLogger(object):
  """For logging profiling events."""

  def _set_summary_dir(self, model_dir):
    """Sets the summary directory to be model_dir."""
    if model_dir is None:
      self._summary_dir = None
      self._summary_writer = None
      logging.warning('profile_logger: model_dir is None, '
                      'so there is nowhere to write summaries.')
      return

    self._summary_dir = os.path.join(model_dir, 'profile')
    try:
      self._summary_writer = writer.FileWriter(
          logdir=self._summary_dir, filename_suffix='.profile_logger')
      logging.info('profile_logger(): set the summary directory to %s',
                   self._summary_dir)
    except Exception:  # pylint: disable=broad-except
      logging.warning('profile_logger(): failed to create %s',
                      self._summary_dir)
      self._summary_dir = None
      self._summary_writer = None

  def __init__(self, model_dir):
    self._set_summary_dir(model_dir)

  def log_event(self, event, phase):
    """Logs the given event to the summary directory."""
    event_name = 'profile/' + event + '_' + phase
    if self._summary_writer is None:
      logging.warning('profile_logger: cannot log event "%s" because there '
                      'is no summary directory.', event_name)
      return

    # For now, we only need the event timestamp. No need to pass any value.
    s = Summary(value=[Summary.Value(tag=event_name, simple_value=0.0)])
    self._summary_writer.add_summary(s)
    self._summary_writer.flush()
    logging.info('profile_logger: log event "%s"', event_name)
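
For readers unfamiliar with the new class, a minimal usage sketch follows (not part of this commit; the model_dir value is purely illustrative). ProfileLogger writes one summary per call under <model_dir>/profile, tagged 'profile/<event>_<phase>':

from tensorflow.python.tpu import profile_logger

# Hypothetical model directory; summaries land in /tmp/example_model_dir/profile.
logger = profile_logger.ProfileLogger('/tmp/example_model_dir')
logger.log_event('train', 'begin')   # summary tag 'profile/train_begin'
# ... run the training loop ...
logger.log_event('train', 'end')     # summary tag 'profile/train_end'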

@@ -62,6 +62,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.tpu import _tpu_estimator_embedding
from tensorflow.python.tpu import error_handling
from tensorflow.python.tpu import functional as tpu_functional
from tensorflow.python.tpu import profile_logger
from tensorflow.python.tpu import session_support
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu
@@ -451,6 +452,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
enqueue_ops,
dequeue_ops,
tpu_compile_op,
prof_logger,
run_infeed_loop_on_coordinator=True,
rendezvous=None,
master=None,
@@ -478,6 +480,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
# initialization.
self._should_initialize_tpu = not ctx.model_parallelism_enabled
self._tpu_compile_op = tpu_compile_op
self._profile_logger = prof_logger
def begin(self):
logging.info('TPU job name %s', self._master_job)
@@ -540,6 +543,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
if self._should_initialize_tpu:
logging.info('Init TPU system')
start = time.time()
self._profile_logger.log_event('init_system', 'begin')
with ops.Graph().as_default():
with tf_session.Session(
self._master, config=self._session_config) as sess:
@@ -547,6 +551,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
tpu.initialize_system(
job=self._master_job,
embedding_config=self._embedding_layer_config))
self._profile_logger.log_event('init_system', 'end')
logging.info('Initialized TPU in %d seconds', time.time() - start)
session.run(self._init_ops,
@@ -593,13 +598,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):
def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op,
def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, prof_logger,
rendezvous=None, master=None, session_config=None):
super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__(
ctx,
enqueue_ops,
dequeue_ops,
tpu_compile_op=tpu_compile_op,
prof_logger=prof_logger,
run_infeed_loop_on_coordinator=False,
rendezvous=rendezvous,
master=master,
@@ -2382,6 +2388,7 @@ class TPUEstimator(estimator_lib.Estimator):
self._is_input_fn_invoked = None
self._rendezvous = {}
self._profile_logger = profile_logger.ProfileLogger(self.model_dir)
def _add_meta_graph_for_mode(self,
builder,
@@ -2711,6 +2718,7 @@ class TPUEstimator(estimator_lib.Estimator):
rendezvous = error_handling.ErrorRendezvous(num_sources=3)
self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous
try:
self._profile_logger.log_event('train', 'begin')
return super(TPUEstimator, self).train(
input_fn=input_fn,
hooks=hooks,
@@ -2720,6 +2728,7 @@ class TPUEstimator(estimator_lib.Estimator):
except Exception: # pylint: disable=broad-except
rendezvous.record_error('training_loop', sys.exc_info())
finally:
self._profile_logger.log_event('train', 'end')
rendezvous.record_done('training_loop')
rendezvous.raise_errors()
@@ -2732,6 +2741,7 @@ class TPUEstimator(estimator_lib.Estimator):
rendezvous = error_handling.ErrorRendezvous(num_sources=3)
self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous
try:
self._profile_logger.log_event('eval', 'begin')
return super(TPUEstimator, self).evaluate(
input_fn,
steps=steps,
@@ -2741,6 +2751,7 @@ class TPUEstimator(estimator_lib.Estimator):
except Exception: # pylint: disable=broad-except
rendezvous.record_error('evaluation_loop', sys.exc_info())
finally:
self._profile_logger.log_event('eval', 'end')
rendezvous.record_done('evaluation_loop')
rendezvous.raise_errors()
@@ -2753,6 +2764,7 @@ class TPUEstimator(estimator_lib.Estimator):
rendezvous = error_handling.ErrorRendezvous(num_sources=3)
self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous
try:
self._profile_logger.log_event('predict', 'begin')
for result in super(TPUEstimator, self).predict(
input_fn=input_fn,
predict_keys=predict_keys,
@@ -2763,6 +2775,7 @@ class TPUEstimator(estimator_lib.Estimator):
except Exception: # pylint: disable=broad-except
rendezvous.record_error('prediction_loop', sys.exc_info())
finally:
self._profile_logger.log_event('predict', 'end')
rendezvous.record_done('prediction_loop')
rendezvous.raise_errors()
@@ -2775,6 +2788,7 @@ class TPUEstimator(estimator_lib.Estimator):
def _model_fn(features, labels, mode, config, params):
"""A Estimator `model_fn` for TPUEstimator."""
self._profile_logger.log_event('model_fn', 'begin')
# `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
# but not in `export_savedmodel()`.
if self._is_input_fn_invoked:
@@ -2814,6 +2828,7 @@ class TPUEstimator(estimator_lib.Estimator):
if self._log_every_n_steps is not None:
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks + (examples_hook,))
self._profile_logger.log_event('model_fn', 'end')
return estimator_spec
assert labels is None, '`labels` passed to `model_fn` must be `None`.'
@@ -2830,10 +2845,12 @@ class TPUEstimator(estimator_lib.Estimator):
tpu_init_ops.append(dummy_table_variables_init)
input_holders = _InputPipeline(input_fn, batch_axis, ctx)
self._profile_logger.log_event('setup_infeed', 'begin')
enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
graph = ops.get_default_graph()
self._profile_logger.log_event('setup_infeed', 'end')
for enqueue_op in enqueue_ops:
if isinstance(enqueue_op, list):
graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
@@ -2897,6 +2914,7 @@ class TPUEstimator(estimator_lib.Estimator):
enqueue_ops,
host_ops,
tpu_compile_op=compile_op,
prof_logger=self._profile_logger,
run_infeed_loop_on_coordinator=(
run_infeed_loop_on_coordinator),
rendezvous=self._rendezvous[mode],
@@ -2947,6 +2965,7 @@ class TPUEstimator(estimator_lib.Estimator):
train_op = control_flow_ops.group(*update_ops)
graph.add_to_collection(_TPU_TRAIN_OP, train_op)
self._profile_logger.log_event('model_fn', 'end')
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
@@ -3022,6 +3041,7 @@ class TPUEstimator(estimator_lib.Estimator):
enqueue_ops,
eval_update_ops + host_ops,
tpu_compile_op=compile_op,
prof_logger=self._profile_logger,
run_infeed_loop_on_coordinator=(
run_infeed_loop_on_coordinator),
rendezvous=self._rendezvous[mode],
@@ -3033,6 +3053,7 @@ class TPUEstimator(estimator_lib.Estimator):
if eval_hooks:
hooks.extend(eval_hooks)
self._profile_logger.log_event('model_fn', 'end')
return model_fn_lib.EstimatorSpec(
mode,
loss=mean_loss,
@@ -3102,6 +3123,7 @@ class TPUEstimator(estimator_lib.Estimator):
TPUInfeedOutfeedSessionHookForPrediction(
ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode],
tpu_compile_op=compile_op,
prof_logger=self._profile_logger,
master=self._config.master,
session_config=self._session_config),
] + input_hooks
@@ -3109,6 +3131,7 @@ class TPUEstimator(estimator_lib.Estimator):
if prediction_hooks:
hooks.extend(prediction_hooks)
self._profile_logger.log_event('model_fn', 'end')
return model_fn_lib.EstimatorSpec(
mode,
prediction_hooks=hooks,
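
To inspect the logged events later, one could read the profile summaries back from disk. A hedged sketch, assuming TF 1.x APIs and the same illustrative model_dir as above (the timestamp of each event is the wall_time of the enclosing summary Event):

import glob
import os
import tensorflow as tf

# ProfileLogger writes event files with the '.profile_logger' suffix here.
profile_dir = os.path.join('/tmp/example_model_dir', 'profile')
for path in glob.glob(os.path.join(profile_dir, '*.profile_logger')):
  for event in tf.train.summary_iterator(path):
    for value in event.summary.value:
      if value.tag.startswith('profile/'):
        print(value.tag, event.wall_time)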