Tensor tracer refactor.
PiperOrigin-RevId: 250316061
This commit is contained in:
parent
c47ec54211
commit
53027266f0
@ -155,6 +155,7 @@ py_library(
|
||||
"session_support.py",
|
||||
"tensor_tracer.py",
|
||||
"tensor_tracer_flags.py",
|
||||
"tensor_tracer_report.py",
|
||||
"topology.py",
|
||||
"tpu.py",
|
||||
"tpu_feed.py",
|
||||
|
@ -40,13 +40,14 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import tensor_tracer_flags
|
||||
from tensorflow.python.tpu import tensor_tracer_report
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
|
||||
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
|
||||
_DEVICE_TYPE_TPU = 'tpu'
|
||||
_DEVICE_TYPE_CPU = 'cpu'
|
||||
_TRACE_MODE_PART_TENSOR_SIZE = 3
|
||||
|
||||
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
|
||||
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
|
||||
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
|
||||
@ -62,28 +63,9 @@ _REASON_USER_EXCLUDED = 'not-traced-user-excluded'
|
||||
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
|
||||
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
|
||||
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'
|
||||
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
|
||||
_MARKER_SECTION_END = '!!!!!!! section-end:'
|
||||
_SECTION_NAME_CONFIG = 'configuration'
|
||||
_SECTION_NAME_REASON = 'reason'
|
||||
_SECTION_NAME_OP_LIST = 'op-list'
|
||||
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
|
||||
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
|
||||
_SECTION_NAME_GRAPH = 'graph'
|
||||
_FIELD_NAME_VERSION = 'version:'
|
||||
_FIELD_NAME_DEVICE = 'device:'
|
||||
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
|
||||
_FIELD_NAME_SUBMODE = 'submode:'
|
||||
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
|
||||
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
|
||||
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
|
||||
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
|
||||
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
|
||||
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
|
||||
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
|
||||
|
||||
_OUTPUT_STREAM_ESCAPE = 'file://'
|
||||
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
|
||||
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
|
||||
_TRACE_FILE_NAME = 'trace.all'
|
||||
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
|
||||
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
|
||||
@ -191,6 +173,7 @@ def _create_tensor_values_cache(graph, num_tensors):
|
||||
use_resource=True,
|
||||
collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
|
||||
|
||||
|
||||
class TensorTracer(object):
|
||||
"""A software construct for tracing tensor values in a TF graph on TPU.
|
||||
|
||||
@ -298,142 +281,28 @@ class TensorTracer(object):
|
||||
|
||||
return '%d %s'%(op_idx, details)
|
||||
|
||||
@staticmethod
|
||||
def topological_sort(g):
|
||||
"""Performs topological sort on the given graph.
|
||||
|
||||
Args:
|
||||
g: the graph.
|
||||
|
||||
Returns:
|
||||
A pair where the first element indicates if the topological
|
||||
sort succeeded (True if there is no cycle found; False if a
|
||||
cycle is found) and the second element is either the sorted
|
||||
list of nodes or the cycle of nodes found.
|
||||
"""
|
||||
def _is_loop_edge(op):
|
||||
"""Returns true if the op is the end of a while-loop creating a cycle."""
|
||||
return op.type in ['NextIteration']
|
||||
|
||||
def _in_op_degree(op):
|
||||
"""Returns the number of incoming edges to the given op.
|
||||
|
||||
The edge calculation skips the edges that come from 'NextIteration' ops.
|
||||
NextIteration creates a cycle in the graph. We break cycles by treating
|
||||
this op as 'sink' and ignoring all outgoing edges from it.
|
||||
Args:
|
||||
op: Tf.Operation
|
||||
Returns:
|
||||
the number of incoming edges.
|
||||
"""
|
||||
count = 0
|
||||
for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
|
||||
if not _is_loop_edge(op):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
sorted_ops = []
|
||||
op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
|
||||
|
||||
frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
|
||||
frontier.sort(key=lambda op: op.name)
|
||||
while frontier:
|
||||
op = frontier.pop()
|
||||
# Remove the op from graph, and remove its outgoing edges.
|
||||
sorted_ops.append(op)
|
||||
if _is_loop_edge(op):
|
||||
continue
|
||||
# pylint: disable=protected-access
|
||||
consumers = list(op._control_outputs)
|
||||
# pylint: enable=protected-access
|
||||
for out_tensor in op.outputs:
|
||||
consumers += [consumer_op for consumer_op in out_tensor.consumers()]
|
||||
consumers.sort(key=lambda op: op.name)
|
||||
for consumer in consumers:
|
||||
# For each deleted edge shift the bucket of the vertex.
|
||||
op_in_degree[consumer] -= 1
|
||||
if op_in_degree[consumer] == 0:
|
||||
frontier.append(consumer)
|
||||
if op_in_degree[consumer] < 0:
|
||||
raise ValueError('consumer:%s degree mismatch'%consumer.name)
|
||||
|
||||
left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
|
||||
if left_ops:
|
||||
return (False, left_ops)
|
||||
else:
|
||||
assert len(g.get_operations()) == len(sorted_ops)
|
||||
return (True, sorted_ops)
|
||||
|
||||
@staticmethod
|
||||
def _make_op_and_tensor_maps(op_list):
|
||||
"""Creates various maps and lists from op_list.
|
||||
|
||||
Args:
|
||||
op_list: a list of Ops
|
||||
|
||||
Returns:
|
||||
opname_idx_map: a map from Op's name to its index in op_list.
|
||||
tensor_list: a list of output tensors of the Ops in op_list.
|
||||
tensorname_idx_map: a map from output tensor name to its index
|
||||
in tensor_list.
|
||||
"""
|
||||
|
||||
opname_idx_map = {}
|
||||
tensor_list = []
|
||||
tensorname_idx_map = {}
|
||||
for op_id, op in enumerate(op_list):
|
||||
if op.name in opname_idx_map:
|
||||
raise ValueError('Duplicated Op name: %s'%op.name)
|
||||
opname_idx_map[op.name] = op_id
|
||||
for output_tensor in op.outputs:
|
||||
if output_tensor.name not in tensorname_idx_map:
|
||||
tensor_list.append(output_tensor)
|
||||
tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
|
||||
return (opname_idx_map, tensor_list, tensorname_idx_map)
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a TensorTracer.
|
||||
|
||||
Sets the various member fields from the flags (if given) or the defaults.
|
||||
"""
|
||||
self._parameters = tensor_tracer_flags.TTParameters()
|
||||
self._set_report_file()
|
||||
self._version = 'use-outside-compilation'
|
||||
self._device_type = None
|
||||
self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
|
||||
self._instrument_records = {}
|
||||
self._num_replicas = None
|
||||
self._num_replicas_per_host = None
|
||||
self._num_hosts = None
|
||||
self._replica_id = None
|
||||
self._tt_config = tensor_tracer_report.TensorTracerConfig()
|
||||
self._parameters = tensor_tracer_flags.TTParameters()
|
||||
self._included_op_full_names = set()
|
||||
|
||||
def _add_replica_id_to_graph(self):
|
||||
"""Adds nodes for computing the replica ID to the graph."""
|
||||
|
||||
if self._num_replicas:
|
||||
if self._tt_config.num_replicas:
|
||||
with ops.control_dependencies(None):
|
||||
# Uses None as dependency to run outside of TPU graph rewrites.
|
||||
self._replica_id = tpu_ops.tpu_replicated_input(
|
||||
list(range(self._num_replicas)),
|
||||
list(range(self._tt_config.num_replicas)),
|
||||
name='tt_replica_id')
|
||||
else:
|
||||
self._replica_id = 'unknown'
|
||||
|
||||
def _set_report_file(self):
|
||||
"""Sets the path of the output report file."""
|
||||
if not self._parameters.report_file_path:
|
||||
self._report_file = None
|
||||
return
|
||||
try:
|
||||
self._report_file = gfile.Open(self._parameters.report_file_path, 'w')
|
||||
except IOError as e:
|
||||
raise e
|
||||
|
||||
def _close_report_file(self):
|
||||
if self._report_file:
|
||||
self._report_file.close()
|
||||
|
||||
def _inside_op_range(self, idx):
|
||||
"""Return True if the given index is inside the selected range."""
|
||||
|
||||
@ -519,107 +388,6 @@ class TensorTracer(object):
|
||||
indices = constant_op.constant([cache_idx])
|
||||
return state_ops.scatter_update(cache, indices, updates).op
|
||||
|
||||
def _write_report(self, content):
|
||||
"""Writes the given content to the report."""
|
||||
|
||||
line = '%s %s'%(_TRACER_LOG_PREFIX, content)
|
||||
if self._report_file:
|
||||
self._report_file.write(line)
|
||||
else:
|
||||
logging.info(line)
|
||||
|
||||
def _write_config_section(self):
|
||||
"""Writes the config section of the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
|
||||
self._parameters.trace_mode))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
|
||||
self._parameters.submode))
|
||||
if self._parameters.included_cores:
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
|
||||
len(self._parameters.included_cores)))
|
||||
else:
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
|
||||
self._num_replicas))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
|
||||
self._num_replicas_per_host))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts))
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))
|
||||
|
||||
def _write_reason_section(self):
|
||||
"""Writes the reason section of the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
|
||||
for key in sorted(self._instrument_records):
|
||||
self._write_report('"%s" %s\n'%(key, self._instrument_records[key]))
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))
|
||||
|
||||
def _write_op_list_section(self, op_list):
|
||||
"""Writes the Op-list section of the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
|
||||
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list)))
|
||||
for i in range(0, len(op_list)):
|
||||
op = op_list[i]
|
||||
line = '%d "%s" %s'%(i, op.name, op.type)
|
||||
for out_tensor in op.outputs:
|
||||
if out_tensor.name not in self._tensorname_idx_map:
|
||||
raise ValueError(
|
||||
'out_tensor %s is not in tensorname_idx_map'%out_tensor.name)
|
||||
line += ' %d'%self._tensorname_idx_map[out_tensor.name]
|
||||
line += '\n'
|
||||
self._write_report(line)
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))
|
||||
|
||||
def _write_tensor_list_section(self, tensor_list, opname_idx_map):
|
||||
"""Writes the tensor-list section of the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
|
||||
_SECTION_NAME_TENSOR_LIST))
|
||||
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list)))
|
||||
for i in range(0, len(tensor_list)):
|
||||
tensor = tensor_list[i]
|
||||
line = '%d "%s"'%(i, tensor.name)
|
||||
consumers = tensor.consumers()
|
||||
consumers.sort(key=lambda op: op.name)
|
||||
for consumer_op in consumers:
|
||||
if consumer_op.name not in opname_idx_map:
|
||||
raise ValueError(
|
||||
'consumer_op %s is not in opname_idx_map'%consumer_op.name)
|
||||
line += ' %d'%opname_idx_map[consumer_op.name]
|
||||
line += '\n'
|
||||
self._write_report(line)
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
|
||||
_SECTION_NAME_TENSOR_LIST))
|
||||
|
||||
def _write_cache_index_map_section(self):
|
||||
"""Writes the mapping from cache index to tensor index to the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
|
||||
_SECTION_NAME_CACHE_INDEX_MAP))
|
||||
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES,
|
||||
len(self._cache_idx_to_tensor_idx)))
|
||||
for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)):
|
||||
tensor_idx = self._cache_idx_to_tensor_idx[cache_idx]
|
||||
line = '%d %d\n'%(cache_idx, tensor_idx)
|
||||
self._write_report(line)
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
|
||||
_SECTION_NAME_CACHE_INDEX_MAP))
|
||||
|
||||
def _write_graph_section(self, succeed, sorted_or_cycle):
|
||||
"""Writes the graph section of the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
|
||||
succeed))
|
||||
l = list(sorted_or_cycle)
|
||||
for i in range(0, len(l)):
|
||||
self._write_report('%d "%s"\n'%(i, l[i].name))
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
|
||||
|
||||
def _preprocess_traced_tensor(self, tensor):
|
||||
"""Computes NAN/Norm/Max on TPUs before sending to CPU.
|
||||
|
||||
@ -699,12 +467,12 @@ class TensorTracer(object):
|
||||
'Tensor trace fun for %s is not yet implemented'
|
||||
% self._parameters.trace_mode)
|
||||
|
||||
def _make_tensor_trace_fun(self, tensor_name):
|
||||
def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
|
||||
"""Makes the tensor tracing function called by outside compilation.
|
||||
|
||||
Args:
|
||||
tensor_name: name of the tensor being traced.
|
||||
|
||||
tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
|
||||
Returns:
|
||||
A function to be passed as the first argument to outside compilation.
|
||||
|
||||
@ -730,7 +498,7 @@ class TensorTracer(object):
|
||||
"""
|
||||
|
||||
if self._parameters.is_brief_mode():
|
||||
if tensor_name not in self._tensorname_idx_map:
|
||||
if tensor_name not in tensor_trace_order.tensorname_idx_map:
|
||||
raise ValueError(
|
||||
'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
|
||||
msg = '%d'%self._tensorname_idx_map[tensor_name]
|
||||
@ -751,7 +519,7 @@ class TensorTracer(object):
|
||||
def _show_part_tensor(tensor):
|
||||
"""Trace function for printing part of the tensor."""
|
||||
|
||||
return _print_tensor(tensor_name, self._part_tensor_size,
|
||||
return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
|
||||
tensor, tensor)
|
||||
|
||||
def _show_full_tensor(tensor):
|
||||
@ -808,47 +576,44 @@ class TensorTracer(object):
|
||||
raise RuntimeError('Tensor trace fun for %s is not yet implemented'
|
||||
%self._parameters.trace_mode)
|
||||
|
||||
def _skip_op(self, op_id, op, user_included, user_excluded,
|
||||
in_exec_path=True):
|
||||
def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
|
||||
"""Returns True if we should not trace Op."""
|
||||
|
||||
if TensorTracer.while_loop_op(op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_WHILELOOP_OP)
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
|
||||
return True
|
||||
if TensorTracer.unsafe_op(op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_UNSAFE_OP)
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
|
||||
return True
|
||||
if TensorTracer.device_mismatch(self._device_type, op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_DEVICE_MISMATCH)
|
||||
if TensorTracer.device_mismatch(self._tt_config.device_type, op):
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
|
||||
return True
|
||||
if not in_exec_path:
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_NOT_EXECUTED)
|
||||
if op not in ops_in_exec_path:
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
|
||||
return True
|
||||
|
||||
if not self._inside_op_range(op_id):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_OUTSIDE_OP_RANGE)
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
|
||||
return True
|
||||
if self._less_interesting_op(op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_LESS_INTERESTING_OP)
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
|
||||
return True
|
||||
if user_included:
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_INCLUDED)
|
||||
if self._is_user_included_op(op):
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
|
||||
return False
|
||||
if user_excluded:
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_EXCLUDED)
|
||||
if self._is_user_excluded_op(op):
|
||||
report_handler.instrument_op(
|
||||
op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
|
||||
return True
|
||||
return False
|
||||
|
||||
def _skip_tensor(self, op_id, out_tensor, user_included,
|
||||
user_excluded):
|
||||
def _skip_tensor(self, op_id, out_tensor, report_handler):
|
||||
"""Returns True if we should not trace out_tensor."""
|
||||
|
||||
# Skips a tensor if the tensor has a non-numeric type.
|
||||
@ -858,22 +623,23 @@ class TensorTracer(object):
|
||||
non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
|
||||
dtypes.string])
|
||||
if out_tensor.dtype in non_numeric_tensor_types:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_NON_NUMERIC_TENSOR)
|
||||
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
|
||||
return True
|
||||
# Skip a tensor if it feeds a special while loop op.
|
||||
if [consumer for consumer in out_tensor.consumers() if
|
||||
TensorTracer.while_loop_op(consumer)]:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_FEEDS_WHILELOOP_OP)
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
|
||||
return True
|
||||
if user_included:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_INCLUDED)
|
||||
if self._is_user_included_op(out_tensor.op):
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
|
||||
return False
|
||||
if user_excluded:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_EXCLUDED)
|
||||
if self._is_user_excluded_op(out_tensor.op):
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
|
||||
return True
|
||||
if not out_tensor.get_shape().is_fully_defined():
|
||||
# If trace mode is nan-inf, norm or max, then the tensor will be reduced
|
||||
@ -883,33 +649,33 @@ class TensorTracer(object):
|
||||
tensor_tracer_flags.TRACE_MODE_NORM,
|
||||
tensor_tracer_flags.TRACE_MODE_MAX_ABS
|
||||
]:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_TENSOR_GET_TRACED)
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
|
||||
return False
|
||||
else:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_DYNAMIC_SHAPE)
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
|
||||
return True
|
||||
rank = len(out_tensor.shape)
|
||||
if rank < 1:
|
||||
# scalar
|
||||
if self._parameters.trace_scalar_ops:
|
||||
if TensorTracer.unsafe_scalar_trace(out_tensor.op):
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_UNSAFE_SCALAR)
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
|
||||
return True
|
||||
else:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_SCALAR_GET_TRACED)
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
|
||||
return False
|
||||
else:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_SKIP_SCALAR)
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
|
||||
return True
|
||||
else:
|
||||
# tensor
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_TENSOR_GET_TRACED)
|
||||
report_handler.instrument_tensor(
|
||||
out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
|
||||
return False
|
||||
|
||||
def _filter_execution_path_operations(self, operations, fetches):
|
||||
@ -951,42 +717,25 @@ class TensorTracer(object):
|
||||
traverse_stack.append(input_op)
|
||||
return execution_path_operations
|
||||
|
||||
def _determine_traced_tensors(self, graph, ops_in_exec_path):
|
||||
"""Determines the tensors that will be traced."""
|
||||
def _determine_and_instrument_traced_tensors(self, graph_order,
|
||||
ops_in_exec_path,
|
||||
tensor_trace_points,
|
||||
report_handler):
|
||||
"""Determines the tensors to trace and instruments the trace details."""
|
||||
|
||||
self._traced_tensorname_to_cache_idx_map = {}
|
||||
self._cache_idx_to_tensor_idx = []
|
||||
operations = graph.get_operations()
|
||||
checkpoint_operations = self._get_checkpoints(graph)
|
||||
for op_id, op in enumerate(operations):
|
||||
if checkpoint_operations and op.name not in checkpoint_operations:
|
||||
traced_tensors = []
|
||||
checkpoint_operations = set([tensor.op
|
||||
for (tensor, _) in tensor_trace_points])
|
||||
for op_id, op in enumerate(graph_order.operations):
|
||||
if checkpoint_operations and op not in checkpoint_operations:
|
||||
continue
|
||||
user_included = self._is_user_included_op(op)
|
||||
user_excluded = self._is_user_excluded_op(op)
|
||||
in_exec_path = op in ops_in_exec_path
|
||||
if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
|
||||
if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
|
||||
continue
|
||||
for i in range(len(op.outputs)):
|
||||
out_tensor = op.outputs[i]
|
||||
if self._skip_tensor(op_id, out_tensor, user_included,
|
||||
user_excluded):
|
||||
continue
|
||||
tensor_name = out_tensor.name
|
||||
if tensor_name in self._traced_tensorname_to_cache_idx_map:
|
||||
raise ValueError(
|
||||
'Tensor name %s should not be already in '
|
||||
'traced_tensorname_to_cache_idx_map'%tensor_name)
|
||||
if tensor_name not in self._tensorname_idx_map:
|
||||
raise ValueError(
|
||||
'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
|
||||
tensor_idx = self._tensorname_idx_map[tensor_name]
|
||||
cache_idx = len(self._traced_tensorname_to_cache_idx_map)
|
||||
self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx
|
||||
self._cache_idx_to_tensor_idx.append(tensor_idx)
|
||||
if len(self._traced_tensorname_to_cache_idx_map) != len(
|
||||
self._cache_idx_to_tensor_idx):
|
||||
raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != '
|
||||
'len(self._cache_idx_to_tensor_idx')
|
||||
if not self._skip_tensor(op_id, out_tensor, report_handler):
|
||||
traced_tensors.append(out_tensor)
|
||||
return traced_tensors
|
||||
|
||||
def _check_trace_files(self):
|
||||
"""Checks if any requirements for trace files are satisfied."""
|
||||
@ -995,7 +744,7 @@ class TensorTracer(object):
|
||||
# traces will be written to stderr. No need to check trace files.
|
||||
return
|
||||
if _trace_files_need_precreated(self._parameters.trace_dir):
|
||||
for replica_id in range(0, self._num_replicas):
|
||||
for replica_id in range(0, self._tt_config.num_replicas):
|
||||
trace_file_path = os.path.join(
|
||||
self._parameters.trace_dir,
|
||||
_COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id
|
||||
@ -1009,56 +758,25 @@ class TensorTracer(object):
|
||||
if not gfile.Exists(self._parameters.trace_dir):
|
||||
raise RuntimeError('Failed to create %s'%self._parameters.trace_dir)
|
||||
|
||||
def _pre_tracing(self, graph, fetches):
|
||||
def _determine_trace_and_create_report(self, graph, ops_in_exec_path):
|
||||
"""Work needs to be done prior to TPU or CPU tracing."""
|
||||
|
||||
self._check_trace_files()
|
||||
operations = graph.get_operations()
|
||||
(opname_idx_map, tensor_list, self._tensorname_idx_map) = (
|
||||
TensorTracer._make_op_and_tensor_maps(operations))
|
||||
self._write_config_section()
|
||||
self._write_op_list_section(operations)
|
||||
self._write_tensor_list_section(tensor_list, opname_idx_map)
|
||||
# Filter out the operations that won't be executed.
|
||||
# if fetches=None, then ops_in_exec_path = set(operations)
|
||||
ops_in_exec_path = self._filter_execution_path_operations(operations,
|
||||
fetches)
|
||||
self._determine_traced_tensors(graph, ops_in_exec_path)
|
||||
self._write_cache_index_map_section()
|
||||
# Does the topological sort before adding any nodes to the graph.
|
||||
(succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
|
||||
|
||||
graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
|
||||
tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)
|
||||
|
||||
report_handler = tensor_tracer_report.TTReportHandle()
|
||||
traced_tensors = self._determine_and_instrument_traced_tensors(
|
||||
graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
|
||||
|
||||
tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
|
||||
traced_tensors)
|
||||
if self._use_tensor_values_cache():
|
||||
_create_tensor_values_cache(graph,
|
||||
len(self._cache_idx_to_tensor_idx))
|
||||
return (ops_in_exec_path, succeed, sorted_or_cycle)
|
||||
|
||||
def _post_tracing(self, succeed, sorted_or_cycle):
|
||||
"""Work needs to be done after TPU or CPU tracing."""
|
||||
|
||||
self._write_reason_section()
|
||||
self._write_graph_section(succeed, sorted_or_cycle)
|
||||
self._close_report_file()
|
||||
|
||||
def _get_checkpoints(self, graph):
|
||||
"""Returns the list of Ops that produce the tensors traced with API.
|
||||
|
||||
Args:
|
||||
graph: the graph of Ops.
|
||||
|
||||
Returns:
|
||||
A set of operation names which should be traced.
|
||||
"""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
|
||||
_TENSOR_TRACER_CHECKPOINT))
|
||||
checkpoint_operations = set()
|
||||
tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION)
|
||||
for (tensor, checkpoint_name) in tensor_tracer_variables:
|
||||
self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
|
||||
checkpoint_operations.add(tensor.op.name)
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
|
||||
_TENSOR_TRACER_CHECKPOINT))
|
||||
return checkpoint_operations
|
||||
_create_tensor_values_cache(graph, len(traced_tensors))
|
||||
report_handler.create_report(self._tt_config, self._parameters,
|
||||
tensor_trace_order, tensor_trace_points)
|
||||
return tensor_trace_order
|
||||
|
||||
def _generate_flush_cache_op(self, graph, start_replica, on_tpu):
|
||||
"""Generates an Op that will flush the cache to file.
|
||||
@ -1085,9 +803,9 @@ class TensorTracer(object):
|
||||
else:
|
||||
replica_id_str = '%d'%replica_id
|
||||
if self._parameters.trace_dir:
|
||||
output_path = os.path.join(self._parameters.trace_dir,
|
||||
_COMPACT_TRACE_FILE_PREFIX) \
|
||||
+ replica_id_str
|
||||
output_path = (os.path.join(self._parameters.trace_dir,
|
||||
_COMPACT_TRACE_FILE_PREFIX)
|
||||
+ replica_id_str)
|
||||
output_stream = _OUTPUT_STREAM_ESCAPE + output_path
|
||||
else:
|
||||
output_stream = sys.stderr
|
||||
@ -1153,7 +871,7 @@ class TensorTracer(object):
|
||||
with ops.control_dependencies(op_fetches +
|
||||
[tensor.op for tensor in tensor_fetches]):
|
||||
flush_cache_op_list = []
|
||||
for host in range(self._num_hosts):
|
||||
for host in range(self._tt_config.num_hosts):
|
||||
start_replica = host * 8
|
||||
flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu)
|
||||
flush_cache_op_list.append(flush_op)
|
||||
@ -1233,8 +951,8 @@ class TensorTracer(object):
|
||||
on_tpu=True):
|
||||
"""Commong tracing function for both CPU and TPUs.
|
||||
|
||||
The caller function should set _device_type, _num_replicas,
|
||||
_num_replicas_per_host, _num_hosts and _replica_id before calling
|
||||
The caller function should set device_type, num_replicas,
|
||||
num_replicas_per_host, num_hosts and replica_id before calling
|
||||
_trace_execution.
|
||||
|
||||
|
||||
@ -1266,15 +984,19 @@ class TensorTracer(object):
|
||||
return math_ops.cast(tensor, dtypes.float32)
|
||||
return tensor
|
||||
|
||||
TensorTracer.check_device_type(self._device_type)
|
||||
TensorTracer.check_device_type(self._tt_config.device_type)
|
||||
# Check in_tensor_fetches, and op_fetches and convert them to lists.
|
||||
processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
|
||||
op_fetches = self._process_op_fetches(op_fetches)
|
||||
all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]
|
||||
|
||||
# Filter the set of ops that will be executed, and topological sort.
|
||||
(exec_op_set, succeed, sorted_or_cycle) = self._pre_tracing(graph,
|
||||
all_fetches)
|
||||
# Filter out the operations that won't be executed.
|
||||
# if fetches=None, then ops_in_exec_path = set(operations)
|
||||
exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
|
||||
all_fetches)
|
||||
# Write report file, and determine the traced tensors.
|
||||
tensor_trace_order = self._determine_trace_and_create_report(
|
||||
graph, exec_op_set)
|
||||
|
||||
tensor_fetch_set = set(processed_t_fetches)
|
||||
tracing_ops = []
|
||||
@ -1290,7 +1012,7 @@ class TensorTracer(object):
|
||||
for i in range(len(op.outputs)):
|
||||
out_tensor = op.outputs[i]
|
||||
tensor_name = out_tensor.name
|
||||
if tensor_name not in self._traced_tensorname_to_cache_idx_map:
|
||||
if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
|
||||
continue
|
||||
# Create the list of consumers before calling _preprocess_traced_tensor.
|
||||
# Otherwise, adding control input below, will introduce a cycle in the
|
||||
@ -1316,7 +1038,7 @@ class TensorTracer(object):
|
||||
processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)
|
||||
|
||||
if self._use_tensor_values_cache():
|
||||
cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name]
|
||||
cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
|
||||
trace_op = self._save_tensor_value_to_cache_op(graph,
|
||||
cache_idx,
|
||||
processed_out_tensor)
|
||||
@ -1324,7 +1046,8 @@ class TensorTracer(object):
|
||||
|
||||
def tpu_wrap_trace_fn(tensor, out_tensor_name):
|
||||
"""Wraps the trace_fn with outside compilation if on TPUs."""
|
||||
tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name)
|
||||
tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
|
||||
tensor_trace_order)
|
||||
if on_tpu:
|
||||
return tpu.outside_compilation(tensor_trace_fn, tensor)
|
||||
else:
|
||||
@ -1374,7 +1097,6 @@ class TensorTracer(object):
|
||||
processed_t_fetches,
|
||||
op_fetches,
|
||||
on_tpu=on_tpu)
|
||||
self._post_tracing(succeed, sorted_or_cycle)
|
||||
# processed_t_fetches is a list at this point. Convert it to the same
|
||||
# format as given in tensor_fetches.
|
||||
return self._convert_fetches_to_input_format(tensor_fetches,
|
||||
@ -1414,21 +1136,23 @@ class TensorTracer(object):
|
||||
return tensor_fetches
|
||||
else:
|
||||
TensorTracer._traced_graphs.add(graph)
|
||||
self._device_type = _DEVICE_TYPE_TPU
|
||||
self._num_replicas = num_replicas
|
||||
self._num_replicas_per_host = num_replicas_per_host
|
||||
self._num_hosts = num_hosts
|
||||
if self._num_replicas is not None:
|
||||
if self._num_replicas_per_host is None:
|
||||
self._num_replicas_per_host = 8
|
||||
if self._num_hosts is None:
|
||||
self._num_hosts = num_replicas // self._num_replicas_per_host + \
|
||||
(num_replicas % self._num_replicas_per_host > 0)
|
||||
|
||||
if self._num_replicas_per_host > 8:
|
||||
self._tt_config.device_type = _DEVICE_TYPE_TPU
|
||||
self._tt_config.num_replicas = num_replicas
|
||||
self._tt_config.num_replicas_per_host = num_replicas_per_host
|
||||
self._tt_config.num_hosts = num_hosts
|
||||
if self._tt_config.num_replicas is not None:
|
||||
if self._tt_config.num_replicas_per_host is None:
|
||||
self._tt_config.num_replicas_per_host = 8
|
||||
if self._tt_config.num_hosts is None:
|
||||
self._tt_config.num_hosts = (
|
||||
num_replicas // self._tt_config.num_replicas_per_host +
|
||||
(num_replicas % self._tt_config.num_replicas_per_host > 0))
|
||||
|
||||
if self._tt_config.num_replicas_per_host > 8:
|
||||
# Checks for the assumption in _generate_flush_cache_op().
|
||||
raise RuntimeError('num_replicas_per_host (%d) is '
|
||||
'greater than 8'%self._num_replicas_per_host)
|
||||
'greater than 8'%self._tt_config.num_replicas_per_host)
|
||||
if self._parameters.graph_dump_path:
|
||||
graph_io.write_graph(graph, self._parameters.graph_dump_path,
|
||||
'graph_before_tt.pbtxt')
|
||||
@ -1467,10 +1191,10 @@ class TensorTracer(object):
|
||||
else:
|
||||
TensorTracer._traced_graphs.add(graph)
|
||||
|
||||
self._device_type = _DEVICE_TYPE_CPU
|
||||
self._num_replicas = 1
|
||||
self._num_replicas_per_host = 1
|
||||
self._num_hosts = 1
|
||||
self._tt_config.device_type = _DEVICE_TYPE_CPU
|
||||
self._tt_config.num_replicas = 1
|
||||
self._tt_config.num_replicas_per_host = 1
|
||||
self._tt_config.num_hosts = 1
|
||||
self._replica_id = 0
|
||||
if self._parameters.graph_dump_path:
|
||||
graph_io.write_graph(graph, self._parameters.graph_dump_path,
|
||||
|
341
tensorflow/python/tpu/tensor_tracer_report.py
Normal file
341
tensorflow/python/tpu/tensor_tracer_report.py
Normal file
@ -0,0 +1,341 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========================================================================
|
||||
"""Tensor Tracer report generation utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
|
||||
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
|
||||
_MARKER_SECTION_END = '!!!!!!! section-end:'
|
||||
|
||||
_SECTION_NAME_CONFIG = 'configuration'
|
||||
_SECTION_NAME_REASON = 'reason'
|
||||
_SECTION_NAME_OP_LIST = 'op-list'
|
||||
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
|
||||
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
|
||||
_SECTION_NAME_GRAPH = 'graph'
|
||||
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
|
||||
|
||||
_FIELD_NAME_VERSION = 'version:'
|
||||
_FIELD_NAME_DEVICE = 'device:'
|
||||
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
|
||||
_FIELD_NAME_SUBMODE = 'submode:'
|
||||
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
|
||||
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
|
||||
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
|
||||
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
|
||||
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
|
||||
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
|
||||
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
|
||||
|
||||
_CURRENT_VERSION = 'use-outside-compilation'
|
||||
|
||||
|
||||
def topological_sort(g):
  """Performs topological sort on the given graph.

  The sort treats 'NextIteration' ops as sinks (their outgoing edges are
  ignored), which breaks the cycles that tf.while_loop constructs create.

  Args:
     g: the graph.

  Returns:
    A pair (contains_cycle, ops) where contains_cycle is True if the graph
    contains a cycle (other than the ones broken via NextIteration ops), in
    which case ops is the set of operations that could not be sorted; it is
    False if the sort succeeded, in which case ops is the topologically
    sorted list of all operations in the graph.

  Raises:
    ValueError: if an op's in-degree bookkeeping goes negative, which would
      indicate an inconsistent graph.
  """
  def _is_loop_edge(op):
    """Returns true if the op is the end of a while-loop creating a cycle."""
    return op.type in ['NextIteration']

  def _in_op_degree(op):
    """Returns the number of incoming edges to the given op.

    The edge calculation skips the edges that come from 'NextIteration' ops.
    NextIteration creates a cycle in the graph. We break cycles by treating
    this op as 'sink' and ignoring all outgoing edges from it.
    Args:
      op: Tf.Operation
    Returns:
      the number of incoming edges.
    """
    count = 0
    # Note: do not shadow the outer `op` parameter here.
    for predecessor in op.control_inputs + [t.op for t in op.inputs]:
      if not _is_loop_edge(predecessor):
        count += 1
    return count

  sorted_ops = []
  op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}

  # Kahn's algorithm: start from the ops with no (non-loop) incoming edges.
  # Sorting by name makes the traversal order deterministic across runs.
  frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
  frontier.sort(key=lambda frontier_op: frontier_op.name)
  while frontier:
    op = frontier.pop()
    # Remove the op from graph, and remove its outgoing edges.
    sorted_ops.append(op)
    if _is_loop_edge(op):
      continue
    # pylint: disable=protected-access
    consumers = list(op._control_outputs)
    # pylint: enable=protected-access
    for out_tensor in op.outputs:
      consumers += list(out_tensor.consumers())
    consumers.sort(key=lambda consumer_op: consumer_op.name)
    for consumer in consumers:
      # For each deleted edge shift the bucket of the vertex.
      op_in_degree[consumer] -= 1
      if op_in_degree[consumer] == 0:
        frontier.append(consumer)
      if op_in_degree[consumer] < 0:
        raise ValueError('consumer:%s degree mismatch'%consumer.name)

  # Any op with remaining in-degree is part of (or downstream of) a cycle.
  left_ops = set(op for (op, degree) in op_in_degree.items() if degree > 0)
  if left_ops:
    return (True, left_ops)
  else:
    assert len(g.get_operations()) == len(sorted_ops)
    return (False, sorted_ops)
|
||||
|
||||
class TensorTracerConfig(object):
  """Tensor Tracer config object."""

  def __init__(self):
    """Initializes the config at the current version with unset topology."""
    self.version = _CURRENT_VERSION
    # The device topology is unknown until the tracer inspects the graph;
    # these fields are filled in later (see trace_tpu/trace_cpu setup).
    self.device_type = None
    self.num_replicas = None
    self.num_replicas_per_host = None
    self.num_hosts = None
|
||||
|
||||
class TensorTraceOrder(object):
  """Class that is responsible from storing the trace-id of the tensors."""

  def __init__(self, graph_order, traced_tensors):
    """Builds the cache-index maps for the traced tensors.

    Args:
      graph_order: a GraphWrapper namedtuple (see sort_tensors_and_ops)
        providing a stable tensor_to_idx mapping for the whole graph.
      traced_tensors: the tensors selected for tracing, in trace order.
    """
    self.graph_order = graph_order
    self.traced_tensors = traced_tensors
    self._create_tensor_maps()

  def _create_tensor_maps(self):
    """Creates tensor to cache id maps.

    Raises:
      ValueError: if a traced tensor is duplicated or is not part of the
        graph order.
      RuntimeError: if the two maps end up with inconsistent sizes.
    """
    self.tensorname_to_cache_idx = {}
    self.cache_idx_to_tensor_idx = []
    for out_tensor in self.traced_tensors:
      tensor_name = out_tensor.name
      if tensor_name in self.tensorname_to_cache_idx:
        raise ValueError(
            'Tensor name %s should not be already in '
            'tensorname_to_cache_idx'%tensor_name)
      if tensor_name not in self.graph_order.tensor_to_idx:
        raise ValueError(
            'Tensor name %s is not in the tensor_to_idx'%tensor_name)
      tensor_idx = self.graph_order.tensor_to_idx[tensor_name]
      # Cache slots are assigned densely in traversal order.
      cache_idx = len(self.tensorname_to_cache_idx)
      self.tensorname_to_cache_idx[tensor_name] = cache_idx
      self.cache_idx_to_tensor_idx.append(tensor_idx)
    if len(self.tensorname_to_cache_idx) != len(
        self.cache_idx_to_tensor_idx):
      # Fixed the truncated message (was missing the closing parenthesis).
      raise RuntimeError('len(self.tensorname_to_cache_idx) != '
                         'len(self.cache_idx_to_tensor_idx)')
|
||||
def sort_tensors_and_ops(graph):
  """Returns a wrapper that has consistent tensor and op orders."""
  graph_wrapper = collections.namedtuple(
      'GraphWrapper',
      ['graph', 'operations', 'op_to_idx', 'tensors', 'tensor_to_idx',
       'contains_cycle', 'topological_order_or_cycle'])
  operations = graph.get_operations()
  # Index ops and their output tensors in graph order so that indices are
  # stable across identical graphs.
  op_to_idx = {op.name: index for index, op in enumerate(operations)}
  tensors = [tensor for op in operations for tensor in op.outputs]
  tensor_to_idx = {tensor.name: index
                   for index, tensor in enumerate(tensors)}
  contains_cycle, topological_order_or_cycle = topological_sort(graph)
  return graph_wrapper(graph=graph, operations=operations, op_to_idx=op_to_idx,
                       tensors=tensors, tensor_to_idx=tensor_to_idx,
                       contains_cycle=contains_cycle,
                       topological_order_or_cycle=topological_order_or_cycle)
|
||||
class OpenReportFile(object):
  """Context manager for writing report file.

  Yields the opened file object, or None when no report_file_path is
  configured (in which case report content is expected to be logged
  instead of written to a file).
  """

  def __init__(self, tt_parameters):
    """Opens the report file if tt_parameters.report_file_path is set."""
    if not tt_parameters.report_file_path:
      self._report_file = None
      return
    # An unwritable path simply propagates the IOError to the caller;
    # the previous `except IOError as e: raise e` added nothing.
    self._report_file = gfile.Open(tt_parameters.report_file_path, 'w')

  def __enter__(self):
    return self._report_file

  def __exit__(self, unused_type, unused_value, unused_traceback):
    if self._report_file:
      self._report_file.close()
|
||||
class TTReportHandle(object):
  """Utility class responsible from creating a tensor tracer report.

  The report is a plain-text file with marker-delimited sections
  (configuration, op-list, tensor-list, trace points, cache-index map,
  reasons, graph order). If no report file is configured, every line is
  sent to logging instead (see _write_report).
  """

  def __init__(self):
    # Maps an op/tensor name to the textual reason it was or was not traced.
    self.instrument_records = {}
    self._report_file = None

  def instrument(self, name, explanation):
    """Records the tracing decision (explanation) for the given name."""
    self.instrument_records[name] = explanation

  def instrument_op(self, op, explanation):
    """Records the tracing decision for the given operation."""
    self.instrument(op.name, explanation)

  def instrument_tensor(self, tensor, explanation):
    """Records the tracing decision for the given tensor."""
    self.instrument(tensor.name, explanation)

  def create_report(self, tt_config, tt_parameters,
                    tensor_trace_order, tensor_trace_points):
    """Creates a report file and writes the trace information."""
    # Binding the context manager result to self._report_file lets all the
    # _write_* helpers below share the open file via _write_report.
    with OpenReportFile(tt_parameters) as self._report_file:
      self._write_config_section(tt_config, tt_parameters)
      self._write_op_list_section(tensor_trace_order.graph_order)
      self._write_tensor_list_section(tensor_trace_order.graph_order)
      self._write_trace_points(tensor_trace_points)
      self._write_cache_index_map_section(tensor_trace_order)
      self._write_reason_section()
      self._write_graph_section(tensor_trace_order.graph_order)

  def _write_trace_points(self, tensor_trace_points):
    """Writes the list of checkpoints."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))
    for (tensor, checkpoint_name) in tensor_trace_points:
      self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))

  def _write_report(self, content):
    """Writes the given content to the report."""

    line = '%s %s'%(_TRACER_LOG_PREFIX, content)
    # Fall back to the info log when no report file was opened.
    if self._report_file:
      self._report_file.write(line)
    else:
      logging.info(line)

  def _write_config_section(self, tt_config, tt_parameters):
    """Writes the config section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
    self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, tt_config.version))
    self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, tt_config.device_type))
    self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
                                  tt_parameters.trace_mode))
    self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
                                  tt_parameters.submode))
    # When tracing is restricted to a subset of cores, report that count
    # instead of the full replica count.
    if tt_parameters.included_cores:
      self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
                                    len(tt_parameters.included_cores)))
    else:
      self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
                                    tt_config.num_replicas))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
                                  tt_config.num_replicas_per_host))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, tt_config.num_hosts))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))

  def _write_reason_section(self):
    """Writes the reason section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
    # Sorted for a deterministic report, independent of insertion order.
    for key in sorted(self.instrument_records):
      self._write_report('"%s" %s\n'%(key, self.instrument_records[key]))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))

  def _write_op_list_section(self, graph_order):
    """Writes the Op-list section of the report.

    Each line is: <op index> "<op name>" <op type> followed by the graph
    indices of the op's output tensors.
    """

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS,
                                  len(graph_order.operations)))
    for i in range(0, len(graph_order.operations)):
      op = graph_order.operations[i]
      line = '%d "%s" %s'%(i, op.name, op.type)
      for out_tensor in op.outputs:
        if out_tensor.name not in graph_order.tensor_to_idx:
          raise ValueError(
              'out_tensor %s is not in tensor_to_idx'%out_tensor.name)
        line += ' %d'%graph_order.tensor_to_idx[out_tensor.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))

  def _write_tensor_list_section(self, graph_order):
    """Writes the tensor-list section of the report.

    Each line is: <tensor index> "<tensor name>" followed by the op indices
    of the tensor's consumers, sorted by consumer op name.
    """

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS,
                                  len(graph_order.tensors)))
    for i in range(0, len(graph_order.tensors)):
      tensor = graph_order.tensors[i]
      line = '%d "%s"'%(i, tensor.name)
      consumers = tensor.consumers()
      consumers.sort(key=lambda op: op.name)
      for consumer_op in consumers:
        if consumer_op.name not in graph_order.op_to_idx:
          raise ValueError(
              'consumer_op %s is not in op_to_idx'%consumer_op.name)
        line += ' %d'%graph_order.op_to_idx[consumer_op.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_LIST))

  def _write_cache_index_map_section(self, tensor_trace_order):
    """Writes the mapping from cache index to tensor index to the report."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_CACHE_INDEX_MAP))
    self._write_report('%s %d\n'%(
        _FIELD_NAME_NUM_CACHE_INDICES,
        len(tensor_trace_order.cache_idx_to_tensor_idx)))
    for cache_idx in range(0, len(tensor_trace_order.cache_idx_to_tensor_idx)):
      tensor_idx = tensor_trace_order.cache_idx_to_tensor_idx[cache_idx]
      line = '%d %d\n'%(cache_idx, tensor_idx)
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_CACHE_INDEX_MAP))

  def _write_graph_section(self, graph_order):
    """Writes the graph section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
    # graph_order.contains_cycle is True when a cycle was found, so the
    # success flag reported here is its negation.
    self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
                                  not graph_order.contains_cycle))
    l = list(graph_order.topological_order_or_cycle)
    for i in range(0, len(l)):
      self._write_report('%d "%s"\n'%(i, l[i].name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
|
Loading…
x
Reference in New Issue
Block a user