Tensor Tracer: enable summary mode for models using low level APIs.

PiperOrigin-RevId: 339513418
Change-Id: I7198c19c68e79b5269c90985a4f299919730548b
This commit is contained in:
Mehmet Deveci 2020-10-28 12:23:12 -07:00 committed by TensorFlower Gardener
parent 018ad69342
commit 26bacff499
2 changed files with 36 additions and 1 deletions

View File

@ -82,6 +82,7 @@ _REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'
_OUTPUT_STREAM_ESCAPE = 'file://'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
@ -178,6 +179,9 @@ def set_parameters(tensor_tracer_params=None):
traced. included_optypes can be set as a regular expression. E.g,
'--included_optypes=some_op_type --excluded_optypes=*.' will trace
only the ops with type 'some_op_type'
- flush_summaries: If summary mode is used, flush_summaries=1 will
flush summaries using outside compilation. Note that, if used with
low level APIs, flush_summaries=1 is necessary to obtain results.
Advanced Flags:
- trace_scalar: Scalar values are not traced by default. If this flag is
set, scalar values will also be traced.
@ -1662,6 +1666,8 @@ class TensorTracer(object):
self._parameters.trace_dir,
filename_suffix=file_suffix,
max_queue=_TT_SUMMARY_MAX_QUEUE)
ops.get_default_graph().add_to_collection(
TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)
with summary_writer.as_default():
summary_metadata = summary_pb2.SummaryMetadata(
plugin_data=summary_pb2.SummaryMetadata.PluginData(
@ -1891,6 +1897,32 @@ class TensorTracer(object):
processed_t_fetches = control_flow_ops.tuple(
processed_t_fetches, control_inputs=[cache_write_op])
del self._host_call_fn[_TT_HOSTCALL_KEY]
elif self._parameters.flush_summaries_with_outside_compile:
write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
step = caches_to_write['step']
tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
if not self._parameters.collect_summary_per_core:
tt_core_summary = self.aggregate_global_cache(tt_core_summary)
def write_if_core_0(step, replica_id, tt_summary):
return control_flow_ops.cond(
math_ops.equal(replica_id, 0),
lambda: write_cache(step=step, event_file_suffix=None, # pylint: disable=g-long-lambda
tensor_tracer_summary=tt_summary),
control_flow_ops.no_op)
write_op = tpu.outside_compilation(write_if_core_0, step=step,
replica_id=self._replica_id,
tt_summary=tt_core_summary)
processed_t_fetches = control_flow_ops.tuple(
processed_t_fetches, control_inputs=[write_op])
del self._host_call_fn[_TT_HOSTCALL_KEY]
else:
raise ValueError('Outside compiled flush in only supported for '
'summary mode')
else:
processed_t_fetches = self._flush_tensor_values_cache(
processed_t_fetches, op_fetches, on_tpu=on_tpu,

View File

@ -71,6 +71,7 @@ FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
FLAG_FLUSH_SUMMARY = 'flush_summaries'
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
@ -138,6 +139,8 @@ class TTParameters(object):
_TT_DEFAULT_TRACE_LEVEL)
self.summary_signatures = self._get_summary_signatures()
self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
self.flush_summaries_with_outside_compile = self.is_flag_on(
FLAG_FLUSH_SUMMARY)
def _get_report_filepath(self):
"""Sets the path of the output report file."""
@ -256,7 +259,7 @@ class TTParameters(object):
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL,
FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE,
FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR,
FLAG_NAME_INSPECT_TRACE
FLAG_NAME_INSPECT_TRACE, FLAG_FLUSH_SUMMARY
]
tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
if not tensor_tracer_flags: