Tensor Tracer: enable summary mode for models using low level APIs.
PiperOrigin-RevId: 339513418 Change-Id: I7198c19c68e79b5269c90985a4f299919730548b
This commit is contained in:
parent
018ad69342
commit
26bacff499
tensorflow/python/tpu
@@ -82,6 +82,7 @@ _REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'
|
||||
|
||||
_OUTPUT_STREAM_ESCAPE = 'file://'
|
||||
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
|
||||
TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers'
|
||||
_TRACE_FILE_NAME = 'trace.all'
|
||||
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
|
||||
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
|
||||
@@ -178,6 +179,9 @@ def set_parameters(tensor_tracer_params=None):
|
||||
traced. included_optypes can be set as a regular expression. E.g,
|
||||
'--included_optypes=some_op_type --excluded_optypes=*.' will trace
|
||||
only the ops with type 'some_op_type'
|
||||
- flush_summaries: If summary mode is used, flush_summaries=1 will
|
||||
flush summaries using outside compilation. Note that, if used with
|
||||
low level APIs, flush_summaries=1 is necessary to obtain results.
|
||||
Advanced Flags:
|
||||
- trace_scalar: Scalar values are not traced by default. If this flag is
|
||||
set, scalar values will also be traced.
|
||||
@@ -1662,6 +1666,8 @@ class TensorTracer(object):
|
||||
self._parameters.trace_dir,
|
||||
filename_suffix=file_suffix,
|
||||
max_queue=_TT_SUMMARY_MAX_QUEUE)
|
||||
ops.get_default_graph().add_to_collection(
|
||||
TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)
|
||||
with summary_writer.as_default():
|
||||
summary_metadata = summary_pb2.SummaryMetadata(
|
||||
plugin_data=summary_pb2.SummaryMetadata.PluginData(
|
||||
@@ -1891,6 +1897,32 @@ class TensorTracer(object):
|
||||
processed_t_fetches = control_flow_ops.tuple(
|
||||
processed_t_fetches, control_inputs=[cache_write_op])
|
||||
del self._host_call_fn[_TT_HOSTCALL_KEY]
|
||||
elif self._parameters.flush_summaries_with_outside_compile:
|
||||
write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
|
||||
if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
|
||||
step = caches_to_write['step']
|
||||
tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
|
||||
tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
|
||||
if not self._parameters.collect_summary_per_core:
|
||||
tt_core_summary = self.aggregate_global_cache(tt_core_summary)
|
||||
|
||||
def write_if_core_0(step, replica_id, tt_summary):
|
||||
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(replica_id, 0),
|
||||
lambda: write_cache(step=step, event_file_suffix=None, # pylint: disable=g-long-lambda
|
||||
tensor_tracer_summary=tt_summary),
|
||||
control_flow_ops.no_op)
|
||||
|
||||
write_op = tpu.outside_compilation(write_if_core_0, step=step,
|
||||
replica_id=self._replica_id,
|
||||
tt_summary=tt_core_summary)
|
||||
processed_t_fetches = control_flow_ops.tuple(
|
||||
processed_t_fetches, control_inputs=[write_op])
|
||||
del self._host_call_fn[_TT_HOSTCALL_KEY]
|
||||
else:
|
||||
raise ValueError('Outside compiled flush in only supported for '
|
||||
'summary mode')
|
||||
else:
|
||||
processed_t_fetches = self._flush_tensor_values_cache(
|
||||
processed_t_fetches, op_fetches, on_tpu=on_tpu,
|
||||
|
@@ -71,6 +71,7 @@ FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
|
||||
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
|
||||
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'
|
||||
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
|
||||
FLAG_FLUSH_SUMMARY = 'flush_summaries'
|
||||
|
||||
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
|
||||
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
|
||||
@@ -138,6 +139,8 @@ class TTParameters(object):
|
||||
_TT_DEFAULT_TRACE_LEVEL)
|
||||
self.summary_signatures = self._get_summary_signatures()
|
||||
self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
|
||||
self.flush_summaries_with_outside_compile = self.is_flag_on(
|
||||
FLAG_FLUSH_SUMMARY)
|
||||
|
||||
def _get_report_filepath(self):
|
||||
"""Sets the path of the output report file."""
|
||||
@@ -256,7 +259,7 @@ class TTParameters(object):
|
||||
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL,
|
||||
FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE,
|
||||
FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR,
|
||||
- FLAG_NAME_INSPECT_TRACE
|
||||
+ FLAG_NAME_INSPECT_TRACE, FLAG_FLUSH_SUMMARY
|
||||
]
|
||||
tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
|
||||
if not tensor_tracer_flags:
|
||||
|
Loading…
Reference in New Issue
Block a user