Add TraceMe in Keras.

PiperOrigin-RevId: 275367420
Change-Id: I422a80815bf20304f302989bfcfe8c150afe49c4
This commit is contained in:
A. Unique TensorFlower 2019-10-17 16:52:16 -07:00 committed by TensorFlower Gardener
parent 01265efe05
commit cc951f765c
2 changed files with 15 additions and 11 deletions
tensorflow/python/keras

View File

@ -213,6 +213,7 @@ py_library(
"//tensorflow/python/module",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/ops/ragged:ragged_util",
"//tensorflow/python/profiler:traceme",
"//tensorflow/python/training/tracking:data_structures",
"//tensorflow/tools/docs:doc_controls",
"@six_archive//:six",

View File

@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_v2_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import traceme
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@ -744,14 +745,16 @@ class TrainingContext(object):
@tf_contextlib.contextmanager
def on_batch(self, step=0, mode=ModeKeys.TRAIN, size=1):
"""Provide a scope for running one batch."""
batch_logs = {'batch': step, 'size': size}
self.callbacks._call_batch_hook(
mode, 'begin', step, batch_logs)
self.progbar.on_batch_begin(step, batch_logs)
try:
yield batch_logs
finally:
if not batch_logs.pop('data_exhausted', False):
self.callbacks._call_batch_hook(
mode, 'end', step, batch_logs)
self.progbar.on_batch_end(step, batch_logs)
with traceme.TraceMe(
'TraceContext', graph_type=mode, step_num=step, batch_size=size):
batch_logs = {'batch': step, 'size': size}
self.callbacks._call_batch_hook(
mode, 'begin', step, batch_logs)
self.progbar.on_batch_begin(step, batch_logs)
try:
yield batch_logs
finally:
if not batch_logs.pop('data_exhausted', False):
self.callbacks._call_batch_hook(
mode, 'end', step, batch_logs)
self.progbar.on_batch_end(step, batch_logs)