Adding an option to tensor tracer to create a suffix folder based on the fingerprint of the tf.graph.

If use_fingerprint_subdirectory is provided, the TensorTracer summaries will be written under <trace_dir>/<fingerprint>. If the graph changes, the new traces will be written under a different fingerprint subdirectory.

PiperOrigin-RevId: 311834837
Change-Id: I9dfbabfeb7fbe58a2a47c2581474ed86647781dc
This commit is contained in:
Mehmet Deveci 2020-05-15 18:07:38 -07:00 committed by TensorFlower Gardener
parent 96f1bbe90a
commit cbc4d5442e
4 changed files with 58 additions and 5 deletions

View File

@ -21,6 +21,10 @@ message TensorTracerReport {
// A map from tensor name to its TracedTensorDef. // A map from tensor name to its TracedTensorDef.
map<string, TracedTensorDef> tensordef = 3; map<string, TracedTensorDef> tensordef = 3;
// The fingerprint of the TensorTracerReport (fingerprint calculation excludes
// this field and graphdef).
string fingerprint = 4;
message TensorTracerConfig { message TensorTracerConfig {
// Tensor tracer version, e.g. hostcall, outside compilation. // Tensor tracer version, e.g. hostcall, outside compilation.
string version = 1; string version = 1;

View File

@ -100,7 +100,7 @@ _TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call' _TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer' _TT_EVENT_FILE_SUFFIX = '.tensor_tracer'
_TT_SUMMARY_MAX_QUEUE = 100 _TT_SUMMARY_MAX_QUEUE = 10
def set_parameters(tensor_tracer_params=None): def set_parameters(tensor_tracer_params=None):
@ -206,6 +206,9 @@ def set_parameters(tensor_tracer_params=None):
-> op2 -> op1 -> op0, if op0 has a NaN and trace_stack_size is 1, the -> op2 -> op1 -> op0, if op0 has a NaN and trace_stack_size is 1, the
result of op1 will also be printed. trace_stack_size is 2, the result result of op1 will also be printed. trace_stack_size is 2, the result
of op1 and op2 will be printed. of op1 and op2 will be printed.
- use_fingerprint_subdirectory: The trace directory will be chosen as
a subdirectory of the provided trace_dir, named using the fingerprint
of the trace metadata.
""" """
flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE
if tensor_tracer_params: if tensor_tracer_params:
@ -547,6 +550,7 @@ class TensorTracer(object):
self._traced_op_names = set() self._traced_op_names = set()
self._report_proto = None self._report_proto = None
self._temp_cache_var = [] self._temp_cache_var = []
self._report_proto_path = ''
def report_proto(self): def report_proto(self):
"""Getter for tensor_tracer.proto object for summary and full_tensor_summary modes. """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes.
@ -564,6 +568,14 @@ class TensorTracer(object):
'Report proto only exists for ' 'Report proto only exists for '
'trace_mode=[summary|full_tensor_summary]') 'trace_mode=[summary|full_tensor_summary]')
def report_proto_path(self):
  """Getter for the path where the tensor_tracer.proto object should be written.

  Returns:
    The report path string currently stored on this tracer.
  """
  proto_path = self._report_proto_path
  return proto_path
def _get_all_cache_variables(self): def _get_all_cache_variables(self):
return self._cache_variables return self._cache_variables
@ -1366,6 +1378,13 @@ class TensorTracer(object):
self._report_proto = report_handler.create_report_proto( self._report_proto = report_handler.create_report_proto(
self._tt_config, self._parameters, tensor_trace_order, self._tt_config, self._parameters, tensor_trace_order,
tensor_trace_points, self._signature_types()) tensor_trace_points, self._signature_types())
if self._parameters.use_fingerprint_subdir:
self._parameters.trace_dir = os.path.join(
self._parameters.trace_dir, self._report_proto.fingerprint)
logging.info('TensorTracer updating trace_dir to %s',
self._parameters.trace_dir)
self._report_proto_path = tensor_tracer_report.report_proto_path(
self._parameters.trace_dir)
if self._parameters.report_file_path != _SKIP_REPORT_FILE: if self._parameters.report_file_path != _SKIP_REPORT_FILE:
report_handler.write_report_proto(self._report_proto, self._parameters) report_handler.write_report_proto(self._report_proto, self._parameters)
else: else:

View File

@ -74,6 +74,7 @@ FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
FLAG_NAME_SUMMARY_SIGNATURES = 'signatures' FLAG_NAME_SUMMARY_SIGNATURES = 'signatures'
FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core' FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache' FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
@ -127,6 +128,7 @@ class TTParameters(object):
self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS) self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS)
self.use_compact_trace = self.is_flag_on(FLAG_NAME_USE_COMPACT_TRACE) self.use_compact_trace = self.is_flag_on(FLAG_NAME_USE_COMPACT_TRACE)
self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR) self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR)
self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR)
# _trace_ops_before_included and _trace_ops_after_included denotes to depth # _trace_ops_before_included and _trace_ops_after_included denotes to depth
# of tracing relative to the ops given in --included_opnames or # of tracing relative to the ops given in --included_opnames or
@ -274,7 +276,7 @@ class TTParameters(object):
FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, FLAG_NAME_OP_RANGE, FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, FLAG_NAME_OP_RANGE,
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL, FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL,
FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE, FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE,
FLAG_NAME_TEMP_CACHE_VAR FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR
] ]
tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR) tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
if not tensor_tracer_flags: if not tensor_tracer_flags:

View File

@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections import collections
import hashlib
import os import os
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tensor_tracer_pb2 from tensorflow.python.tpu import tensor_tracer_pb2
@ -53,6 +55,18 @@ _CURRENT_VERSION = 'use-outside-compilation'
_TT_REPORT_PROTO = 'tensor_tracer_report.report_pb' _TT_REPORT_PROTO = 'tensor_tracer_report.report_pb'
def report_proto_path(trace_dir):
  """Builds the location of the report proto inside a trace directory.

  Args:
    trace_dir: String denoting the trace directory.

  Returns:
    A string: trace_dir joined with the report proto file name.
  """
  proto_path = os.path.join(trace_dir, _TT_REPORT_PROTO)
  return proto_path
def topological_sort(g): def topological_sort(g):
"""Performs topological sort on the given graph. """Performs topological sort on the given graph.
@ -206,6 +220,12 @@ class OpenReportFile(object):
self._report_file.close() self._report_file.close()
def proto_fingerprint(message_proto):
  """Returns a fingerprint (SHA-256 hex digest) of the given proto message.

  Args:
    message_proto: A protocol buffer message. It must support
      SerializeToString(deterministic=True).

  Returns:
    A string holding the hexadecimal SHA-256 digest of the serialized message.
  """
  # Default proto serialization gives no ordering guarantee for map fields,
  # so two equal messages could hash differently. Request deterministic
  # serialization so that equal messages yield equal fingerprints.
  serialized_message = message_proto.SerializeToString(deterministic=True)
  return hashlib.sha256(serialized_message).hexdigest()
class TTReportHandle(object): class TTReportHandle(object):
"""Utility class responsible from creating a tensor tracer report.""" """Utility class responsible from creating a tensor tracer report."""
@ -255,8 +275,6 @@ class TTReportHandle(object):
key=lambda x: x[1]): key=lambda x: x[1]):
report.config.signatures.append(signature_name) report.config.signatures.append(signature_name)
tf_graph = tensor_trace_order.graph_order.graph
report.graphdef.CopyFrom(tf_graph.as_graph_def())
for tensor in tensor_trace_order.graph_order.tensors: for tensor in tensor_trace_order.graph_order.tensors:
tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef() tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef()
tensor_def.name = tensor.name tensor_def.name = tensor.name
@ -265,6 +283,11 @@ class TTReportHandle(object):
tensor_def.cache_index = ( tensor_def.cache_index = (
tensor_trace_order.tensorname_to_cache_idx[tensor.name]) tensor_trace_order.tensorname_to_cache_idx[tensor.name])
else: else:
# To prevent small changes from affecting the fingerprint calculation, avoid
# writing the untraced tensors to metadata. Fingerprints will be
# different only when the list of the traced tensors is different.
if tt_parameters.use_fingerprint_subdir:
continue
tensor_def.is_traced = False tensor_def.is_traced = False
if tensor.name in tensor_trace_points: if tensor.name in tensor_trace_points:
@ -274,12 +297,17 @@ class TTReportHandle(object):
elif tensor.op.name in self.instrument_records: elif tensor.op.name in self.instrument_records:
tensor_def.explanation = self.instrument_records[tensor.op.name] tensor_def.explanation = self.instrument_records[tensor.op.name]
report.tensordef[tensor.name].CopyFrom(tensor_def) report.tensordef[tensor.name].CopyFrom(tensor_def)
report.fingerprint = proto_fingerprint(report)
logging.info('TensorTracerProto fingerprint is %s.',
report.fingerprint)
tf_graph = tensor_trace_order.graph_order.graph
report.graphdef.CopyFrom(tf_graph.as_graph_def())
return report return report
def write_report_proto(self, report_proto, tt_parameters): def write_report_proto(self, report_proto, tt_parameters):
"""Writes the given report proto under trace_dir.""" """Writes the given report proto under trace_dir."""
gfile.MakeDirs(tt_parameters.trace_dir) gfile.MakeDirs(tt_parameters.trace_dir)
report_path = os.path.join(tt_parameters.trace_dir, _TT_REPORT_PROTO) report_path = report_proto_path(tt_parameters.trace_dir)
with gfile.GFile(report_path, 'wb') as f: with gfile.GFile(report_path, 'wb') as f:
f.write(report_proto.SerializeToString()) f.write(report_proto.SerializeToString())