Add TraceMe in Keras.
PiperOrigin-RevId: 275367420 Change-Id: I422a80815bf20304f302989bfcfe8c150afe49c4
This commit is contained in:
parent
01265efe05
commit
cc951f765c
tensorflow/python/keras
@ -213,6 +213,7 @@ py_library(
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:ragged_util",
|
||||
"//tensorflow/python/profiler:traceme",
|
||||
"//tensorflow/python/training/tracking:data_structures",
|
||||
"//tensorflow/tools/docs:doc_controls",
|
||||
"@six_archive//:six",
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.engine import training_v2_utils
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler import traceme
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
@ -744,14 +745,16 @@ class TrainingContext(object):
|
||||
@tf_contextlib.contextmanager
|
||||
def on_batch(self, step=0, mode=ModeKeys.TRAIN, size=1):
|
||||
"""Provide a scope for running one batch."""
|
||||
batch_logs = {'batch': step, 'size': size}
|
||||
self.callbacks._call_batch_hook(
|
||||
mode, 'begin', step, batch_logs)
|
||||
self.progbar.on_batch_begin(step, batch_logs)
|
||||
try:
|
||||
yield batch_logs
|
||||
finally:
|
||||
if not batch_logs.pop('data_exhausted', False):
|
||||
self.callbacks._call_batch_hook(
|
||||
mode, 'end', step, batch_logs)
|
||||
self.progbar.on_batch_end(step, batch_logs)
|
||||
with traceme.TraceMe(
|
||||
'TraceContext', graph_type=mode, step_num=step, batch_size=size):
|
||||
batch_logs = {'batch': step, 'size': size}
|
||||
self.callbacks._call_batch_hook(
|
||||
mode, 'begin', step, batch_logs)
|
||||
self.progbar.on_batch_begin(step, batch_logs)
|
||||
try:
|
||||
yield batch_logs
|
||||
finally:
|
||||
if not batch_logs.pop('data_exhausted', False):
|
||||
self.callbacks._call_batch_hook(
|
||||
mode, 'end', step, batch_logs)
|
||||
self.progbar.on_batch_end(step, batch_logs)
|
||||
|
Loading…
Reference in New Issue
Block a user