Tensor tracer refactor.

PiperOrigin-RevId: 250316061
A. Unique TensorFlower 2019-05-28 10:44:19 -07:00 committed by TensorFlower Gardener
parent c47ec54211
commit 53027266f0
3 changed files with 465 additions and 399 deletions

tensorflow/python/tpu/BUILD

@@ -155,6 +155,7 @@ py_library(
"session_support.py",
"tensor_tracer.py",
"tensor_tracer_flags.py",
"tensor_tracer_report.py",
"topology.py",
"tpu.py",
"tpu_feed.py",

tensorflow/python/tpu/tensor_tracer.py

@@ -40,13 +40,14 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tensor_tracer_flags
from tensorflow.python.tpu import tensor_tracer_report
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.ops import tpu_ops
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
_TRACE_MODE_PART_TENSOR_SIZE = 3
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
@@ -62,28 +63,9 @@ _REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'
_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
_OUTPUT_STREAM_ESCAPE = 'file://'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
@@ -191,6 +173,7 @@ def _create_tensor_values_cache(graph, num_tensors):
use_resource=True,
collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
class TensorTracer(object):
"""A software construct for tracing tensor values in a TF graph on TPU.
@@ -298,142 +281,28 @@ class TensorTracer(object):
return '%d %s'%(op_idx, details)
@staticmethod
def topological_sort(g):
"""Performs topological sort on the given graph.
Args:
g: the graph.
Returns:
A pair where the first element indicates if the topological
sort succeeded (True if there is no cycle found; False if a
cycle is found) and the second element is either the sorted
list of nodes or the cycle of nodes found.
"""
def _is_loop_edge(op):
"""Returns true if the op is the end of a while-loop creating a cycle."""
return op.type in ['NextIteration']
def _in_op_degree(op):
"""Returns the number of incoming edges to the given op.
The edge calculation skips the edges that come from 'NextIteration' ops.
NextIteration creates a cycle in the graph. We break cycles by treating
this op as 'sink' and ignoring all outgoing edges from it.
Args:
op: tf.Operation
Returns:
the number of incoming edges.
"""
count = 0
for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
if not _is_loop_edge(op):
count += 1
return count
sorted_ops = []
op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
frontier.sort(key=lambda op: op.name)
while frontier:
op = frontier.pop()
# Remove the op from graph, and remove its outgoing edges.
sorted_ops.append(op)
if _is_loop_edge(op):
continue
# pylint: disable=protected-access
consumers = list(op._control_outputs)
# pylint: enable=protected-access
for out_tensor in op.outputs:
consumers += [consumer_op for consumer_op in out_tensor.consumers()]
consumers.sort(key=lambda op: op.name)
for consumer in consumers:
# For each deleted edge shift the bucket of the vertex.
op_in_degree[consumer] -= 1
if op_in_degree[consumer] == 0:
frontier.append(consumer)
if op_in_degree[consumer] < 0:
raise ValueError('consumer:%s degree mismatch'%consumer.name)
left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
if left_ops:
return (False, left_ops)
else:
assert len(g.get_operations()) == len(sorted_ops)
return (True, sorted_ops)
@staticmethod
def _make_op_and_tensor_maps(op_list):
"""Creates various maps and lists from op_list.
Args:
op_list: a list of Ops
Returns:
opname_idx_map: a map from Op's name to its index in op_list.
tensor_list: a list of output tensors of the Ops in op_list.
tensorname_idx_map: a map from output tensor name to its index
in tensor_list.
"""
opname_idx_map = {}
tensor_list = []
tensorname_idx_map = {}
for op_id, op in enumerate(op_list):
if op.name in opname_idx_map:
raise ValueError('Duplicated Op name: %s'%op.name)
opname_idx_map[op.name] = op_id
for output_tensor in op.outputs:
if output_tensor.name not in tensorname_idx_map:
tensor_list.append(output_tensor)
tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
return (opname_idx_map, tensor_list, tensorname_idx_map)
def __init__(self):
"""Initializes a TensorTracer.
Sets the various member fields from the flags (if given) or the defaults.
"""
self._parameters = tensor_tracer_flags.TTParameters()
self._set_report_file()
self._version = 'use-outside-compilation'
self._device_type = None
self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
self._instrument_records = {}
self._num_replicas = None
self._num_replicas_per_host = None
self._num_hosts = None
self._replica_id = None
self._tt_config = tensor_tracer_report.TensorTracerConfig()
self._parameters = tensor_tracer_flags.TTParameters()
self._included_op_full_names = set()
def _add_replica_id_to_graph(self):
"""Adds nodes for computing the replica ID to the graph."""
if self._num_replicas:
if self._tt_config.num_replicas:
with ops.control_dependencies(None):
# Uses None as dependency to run outside of TPU graph rewrites.
self._replica_id = tpu_ops.tpu_replicated_input(
list(range(self._num_replicas)),
list(range(self._tt_config.num_replicas)),
name='tt_replica_id')
else:
self._replica_id = 'unknown'
def _set_report_file(self):
"""Sets the path of the output report file."""
if not self._parameters.report_file_path:
self._report_file = None
return
try:
self._report_file = gfile.Open(self._parameters.report_file_path, 'w')
except IOError as e:
raise e
def _close_report_file(self):
if self._report_file:
self._report_file.close()
def _inside_op_range(self, idx):
"""Return True if the given index is inside the selected range."""
@@ -519,107 +388,6 @@ class TensorTracer(object):
indices = constant_op.constant([cache_idx])
return state_ops.scatter_update(cache, indices, updates).op
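The two lines above are the tail of _save_tensor_value_to_cache_op: the compact cache created by _create_tensor_values_cache is a 1-D float variable with one slot per traced tensor, and each trace writes its slot with a scatter update. A standalone sketch of that pattern, with hypothetical names and sizes (assumes TF 1.x graph mode):

import tensorflow.compat.v1 as tf

num_traced_tensors = 4  # hypothetical cache size
with tf.Graph().as_default():
  cache = tf.get_variable(
      'tt_cache', shape=[num_traced_tensors], dtype=tf.float32,
      initializer=tf.constant_initializer(-1.0),  # entries start at -1.0
      use_resource=True)
  cache_idx = 2                  # slot assigned to this traced tensor
  updates = tf.constant([0.5])   # value produced by the trace function
  # The .op form lets the write be used as a control dependency.
  write_op = tf.scatter_update(cache, tf.constant([cache_idx]), updates).op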
def _write_report(self, content):
"""Writes the given content to the report."""
line = '%s %s'%(_TRACER_LOG_PREFIX, content)
if self._report_file:
self._report_file.write(line)
else:
logging.info(line)
def _write_config_section(self):
"""Writes the config section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version))
self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type))
self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
self._parameters.trace_mode))
self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
self._parameters.submode))
if self._parameters.included_cores:
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
len(self._parameters.included_cores)))
else:
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
self._num_replicas))
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
self._num_replicas_per_host))
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))
def _write_reason_section(self):
"""Writes the reason section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
for key in sorted(self._instrument_records):
self._write_report('"%s" %s\n'%(key, self._instrument_records[key]))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))
def _write_op_list_section(self, op_list):
"""Writes the Op-list section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list)))
for i in range(0, len(op_list)):
op = op_list[i]
line = '%d "%s" %s'%(i, op.name, op.type)
for out_tensor in op.outputs:
if out_tensor.name not in self._tensorname_idx_map:
raise ValueError(
'out_tensor %s is not in tensorname_idx_map'%out_tensor.name)
line += ' %d'%self._tensorname_idx_map[out_tensor.name]
line += '\n'
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))
def _write_tensor_list_section(self, tensor_list, opname_idx_map):
"""Writes the tensor-list section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_SECTION_NAME_TENSOR_LIST))
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list)))
for i in range(0, len(tensor_list)):
tensor = tensor_list[i]
line = '%d "%s"'%(i, tensor.name)
consumers = tensor.consumers()
consumers.sort(key=lambda op: op.name)
for consumer_op in consumers:
if consumer_op.name not in opname_idx_map:
raise ValueError(
'consumer_op %s is not in opname_idx_map'%consumer_op.name)
line += ' %d'%opname_idx_map[consumer_op.name]
line += '\n'
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_SECTION_NAME_TENSOR_LIST))
def _write_cache_index_map_section(self):
"""Writes the mapping from cache index to tensor index to the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_SECTION_NAME_CACHE_INDEX_MAP))
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES,
len(self._cache_idx_to_tensor_idx)))
for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)):
tensor_idx = self._cache_idx_to_tensor_idx[cache_idx]
line = '%d %d\n'%(cache_idx, tensor_idx)
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_SECTION_NAME_CACHE_INDEX_MAP))
def _write_graph_section(self, succeed, sorted_or_cycle):
"""Writes the graph section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
succeed))
l = list(sorted_or_cycle)
for i in range(0, len(l)):
self._write_report('%d "%s"\n'%(i, l[i].name))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
def _preprocess_traced_tensor(self, tensor):
"""Computes NAN/Norm/Max on TPUs before sending to CPU.
@@ -699,12 +467,12 @@ class TensorTracer(object):
'Tensor trace fun for %s is not yet implemented'
% self._parameters.trace_mode)
def _make_tensor_trace_fun(self, tensor_name):
def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
"""Makes the tensor tracing function called by outside compilation.
Args:
tensor_name: name of the tensor being traced.
tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
Returns:
A function to be passed as the first argument to outside compilation.
@@ -730,7 +498,7 @@
"""
if self._parameters.is_brief_mode():
if tensor_name not in self._tensorname_idx_map:
if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
raise ValueError(
'Tensor name %s is not in the tensorname_to_cache_idx'%tensor_name)
msg = '%d'%tensor_trace_order.tensorname_to_cache_idx[tensor_name]
@@ -751,7 +519,7 @@
def _show_part_tensor(tensor):
"""Trace function for printing part of the tensor."""
return _print_tensor(tensor_name, self._part_tensor_size,
return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
tensor, tensor)
def _show_full_tensor(tensor):
@@ -808,47 +576,44 @@ class TensorTracer(object):
raise RuntimeError('Tensor trace fun for %s is not yet implemented'
%self._parameters.trace_mode)
def _skip_op(self, op_id, op, user_included, user_excluded,
in_exec_path=True):
def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
"""Returns True if we should not trace Op."""
if TensorTracer.while_loop_op(op):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_WHILELOOP_OP)
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
return True
if TensorTracer.unsafe_op(op):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_UNSAFE_OP)
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
return True
if TensorTracer.device_mismatch(self._device_type, op):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_DEVICE_MISMATCH)
if TensorTracer.device_mismatch(self._tt_config.device_type, op):
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
return True
if not in_exec_path:
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_NOT_EXECUTED)
if op not in ops_in_exec_path:
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
return True
if not self._inside_op_range(op_id):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_OUTSIDE_OP_RANGE)
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
return True
if self._less_interesting_op(op):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_LESS_INTERESTING_OP)
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
return True
if user_included:
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_USER_INCLUDED)
if self._is_user_included_op(op):
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
return False
if user_excluded:
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_USER_EXCLUDED)
if self._is_user_excluded_op(op):
report_handler.instrument_op(
op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
return True
return False
def _skip_tensor(self, op_id, out_tensor, user_included,
user_excluded):
def _skip_tensor(self, op_id, out_tensor, report_handler):
"""Returns True if we should not trace out_tensor."""
# Skips a tensor if the tensor has a non-numeric type.
@@ -858,22 +623,23 @@ class TensorTracer(object):
non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
dtypes.string])
if out_tensor.dtype in non_numeric_tensor_types:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_NON_NUMERIC_TENSOR)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
return True
# Skip a tensor if it feeds a special while loop op.
if [consumer for consumer in out_tensor.consumers() if
TensorTracer.while_loop_op(consumer)]:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_FEEDS_WHILELOOP_OP)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
return True
if user_included:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_USER_INCLUDED)
if self._is_user_included_op(out_tensor.op):
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
return False
if user_excluded:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_USER_EXCLUDED)
if self._is_user_excluded_op(out_tensor.op):
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
return True
if not out_tensor.get_shape().is_fully_defined():
# If trace mode is nan-inf, norm or max, then the tensor will be reduced
@@ -883,33 +649,33 @@ class TensorTracer(object):
tensor_tracer_flags.TRACE_MODE_NORM,
tensor_tracer_flags.TRACE_MODE_MAX_ABS
]:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_TENSOR_GET_TRACED)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
return False
else:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_DYNAMIC_SHAPE)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
return True
rank = len(out_tensor.shape)
if rank < 1:
# scalar
if self._parameters.trace_scalar_ops:
if TensorTracer.unsafe_scalar_trace(out_tensor.op):
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_UNSAFE_SCALAR)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
return True
else:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_SCALAR_GET_TRACED)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
return False
else:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_SKIP_SCALAR)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
return True
else:
# tensor
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_TENSOR_GET_TRACED)
report_handler.instrument_tensor(
out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
return False
def _filter_execution_path_operations(self, operations, fetches):
@@ -951,42 +717,25 @@ class TensorTracer(object):
traverse_stack.append(input_op)
return execution_path_operations
def _determine_traced_tensors(self, graph, ops_in_exec_path):
"""Determines the tensors that will be traced."""
def _determine_and_instrument_traced_tensors(self, graph_order,
ops_in_exec_path,
tensor_trace_points,
report_handler):
"""Determines the tensors to trace and instruments the trace details."""
self._traced_tensorname_to_cache_idx_map = {}
self._cache_idx_to_tensor_idx = []
operations = graph.get_operations()
checkpoint_operations = self._get_checkpoints(graph)
for op_id, op in enumerate(operations):
if checkpoint_operations and op.name not in checkpoint_operations:
traced_tensors = []
checkpoint_operations = set([tensor.op
for (tensor, _) in tensor_trace_points])
for op_id, op in enumerate(graph_order.operations):
if checkpoint_operations and op not in checkpoint_operations:
continue
user_included = self._is_user_included_op(op)
user_excluded = self._is_user_excluded_op(op)
in_exec_path = op in ops_in_exec_path
if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
continue
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
if self._skip_tensor(op_id, out_tensor, user_included,
user_excluded):
continue
tensor_name = out_tensor.name
if tensor_name in self._traced_tensorname_to_cache_idx_map:
raise ValueError(
'Tensor name %s should not be already in '
'traced_tensorname_to_cache_idx_map'%tensor_name)
if tensor_name not in self._tensorname_idx_map:
raise ValueError(
'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
tensor_idx = self._tensorname_idx_map[tensor_name]
cache_idx = len(self._traced_tensorname_to_cache_idx_map)
self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx
self._cache_idx_to_tensor_idx.append(tensor_idx)
if len(self._traced_tensorname_to_cache_idx_map) != len(
self._cache_idx_to_tensor_idx):
raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != '
'len(self._cache_idx_to_tensor_idx)')
if not self._skip_tensor(op_id, out_tensor, report_handler):
traced_tensors.append(out_tensor)
return traced_tensors
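Stripped of the reporting calls, the selection loop above reduces to this miniature, where should_skip_op and should_skip_tensor are hypothetical stand-ins for _skip_op and _skip_tensor:

def pick_traced_tensors(operations, should_skip_op, should_skip_tensor):
  """Collects every output tensor whose op and tensor pass both filters."""
  traced = []
  for op_id, op in enumerate(operations):
    if should_skip_op(op_id, op):
      continue
    traced.extend(t for t in op.outputs if not should_skip_tensor(op_id, t))
  return traced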
def _check_trace_files(self):
"""Checks if any requirements for trace files are satisfied."""
@@ -995,7 +744,7 @@ class TensorTracer(object):
# traces will be written to stderr. No need to check trace files.
return
if _trace_files_need_precreated(self._parameters.trace_dir):
for replica_id in range(0, self._num_replicas):
for replica_id in range(0, self._tt_config.num_replicas):
trace_file_path = os.path.join(
self._parameters.trace_dir,
_COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id
@@ -1009,56 +758,25 @@ class TensorTracer(object):
if not gfile.Exists(self._parameters.trace_dir):
raise RuntimeError('Failed to create %s'%self._parameters.trace_dir)
def _pre_tracing(self, graph, fetches):
def _determine_trace_and_create_report(self, graph, ops_in_exec_path):
"""Work needs to be done prior to TPU or CPU tracing."""
self._check_trace_files()
operations = graph.get_operations()
(opname_idx_map, tensor_list, self._tensorname_idx_map) = (
TensorTracer._make_op_and_tensor_maps(operations))
self._write_config_section()
self._write_op_list_section(operations)
self._write_tensor_list_section(tensor_list, opname_idx_map)
# Filter out the operations that won't be executed.
# if fetches=None, then ops_in_exec_path = set(operations)
ops_in_exec_path = self._filter_execution_path_operations(operations,
fetches)
self._determine_traced_tensors(graph, ops_in_exec_path)
self._write_cache_index_map_section()
# Does the topological sort before adding any nodes to the graph.
(succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)
report_handler = tensor_tracer_report.TTReportHandle()
traced_tensors = self._determine_and_instrument_traced_tensors(
graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
traced_tensors)
if self._use_tensor_values_cache():
_create_tensor_values_cache(graph,
len(self._cache_idx_to_tensor_idx))
return (ops_in_exec_path, succeed, sorted_or_cycle)
def _post_tracing(self, succeed, sorted_or_cycle):
"""Work needs to be done after TPU or CPU tracing."""
self._write_reason_section()
self._write_graph_section(succeed, sorted_or_cycle)
self._close_report_file()
def _get_checkpoints(self, graph):
"""Returns the list of Ops that produce the tensors traced with API.
Args:
graph: the graph of Ops.
Returns:
A set of operation names which should be traced.
"""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_TENSOR_TRACER_CHECKPOINT))
checkpoint_operations = set()
tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION)
for (tensor, checkpoint_name) in tensor_tracer_variables:
self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
checkpoint_operations.add(tensor.op.name)
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_TENSOR_TRACER_CHECKPOINT))
return checkpoint_operations
_create_tensor_values_cache(graph, len(traced_tensors))
report_handler.create_report(self._tt_config, self._parameters,
tensor_trace_order, tensor_trace_points)
return tensor_trace_order
def _generate_flush_cache_op(self, graph, start_replica, on_tpu):
"""Generates an Op that will flush the cache to file.
@@ -1085,9 +803,9 @@ class TensorTracer(object):
else:
replica_id_str = '%d'%replica_id
if self._parameters.trace_dir:
output_path = os.path.join(self._parameters.trace_dir,
_COMPACT_TRACE_FILE_PREFIX) \
+ replica_id_str
output_path = (os.path.join(self._parameters.trace_dir,
_COMPACT_TRACE_FILE_PREFIX)
+ replica_id_str)
output_stream = _OUTPUT_STREAM_ESCAPE + output_path
else:
output_stream = sys.stderr
@@ -1153,7 +871,7 @@ class TensorTracer(object):
with ops.control_dependencies(op_fetches +
[tensor.op for tensor in tensor_fetches]):
flush_cache_op_list = []
for host in range(self._num_hosts):
for host in range(self._tt_config.num_hosts):
start_replica = host * 8
flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu)
flush_cache_op_list.append(flush_op)
@@ -1233,8 +951,8 @@ class TensorTracer(object):
on_tpu=True):
"""Commong tracing function for both CPU and TPUs.
The caller function should set _device_type, _num_replicas,
_num_replicas_per_host, _num_hosts and _replica_id before calling
The caller function should set device_type, num_replicas,
num_replicas_per_host, num_hosts and replica_id before calling
_trace_execution.
@@ -1266,15 +984,19 @@ class TensorTracer(object):
return math_ops.cast(tensor, dtypes.float32)
return tensor
TensorTracer.check_device_type(self._device_type)
TensorTracer.check_device_type(self._tt_config.device_type)
# Check in_tensor_fetches, and op_fetches and convert them to lists.
processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
op_fetches = self._process_op_fetches(op_fetches)
all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]
# Filter the set of ops that will be executed, and topological sort.
(exec_op_set, succeed, sorted_or_cycle) = self._pre_tracing(graph,
all_fetches)
# Filter out the operations that won't be executed.
# if fetches=None, then ops_in_exec_path = set(operations)
exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
all_fetches)
# Write report file, and determine the traced tensors.
tensor_trace_order = self._determine_trace_and_create_report(
graph, exec_op_set)
tensor_fetch_set = set(processed_t_fetches)
tracing_ops = []
@@ -1290,7 +1012,7 @@ class TensorTracer(object):
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
tensor_name = out_tensor.name
if tensor_name not in self._traced_tensorname_to_cache_idx_map:
if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
continue
# Create the list of consumers before calling _preprocess_traced_tensor.
# Otherwise, adding control input below, will introduce a cycle in the
@@ -1316,7 +1038,7 @@ class TensorTracer(object):
processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)
if self._use_tensor_values_cache():
cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name]
cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
trace_op = self._save_tensor_value_to_cache_op(graph,
cache_idx,
processed_out_tensor)
@@ -1324,7 +1046,8 @@ class TensorTracer(object):
def tpu_wrap_trace_fn(tensor, out_tensor_name):
"""Wraps the trace_fn with outside compilation if on TPUs."""
tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name)
tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
tensor_trace_order)
if on_tpu:
return tpu.outside_compilation(tensor_trace_fn, tensor)
else:
@@ -1374,7 +1097,6 @@ class TensorTracer(object):
processed_t_fetches,
op_fetches,
on_tpu=on_tpu)
self._post_tracing(succeed, sorted_or_cycle)
# processed_t_fetches is a list at this point. Convert it to the same
# format as given in tensor_fetches.
return self._convert_fetches_to_input_format(tensor_fetches,
@@ -1414,21 +1136,23 @@ class TensorTracer(object):
return tensor_fetches
else:
TensorTracer._traced_graphs.add(graph)
self._device_type = _DEVICE_TYPE_TPU
self._num_replicas = num_replicas
self._num_replicas_per_host = num_replicas_per_host
self._num_hosts = num_hosts
if self._num_replicas is not None:
if self._num_replicas_per_host is None:
self._num_replicas_per_host = 8
if self._num_hosts is None:
self._num_hosts = num_replicas // self._num_replicas_per_host + \
(num_replicas % self._num_replicas_per_host > 0)
if self._num_replicas_per_host > 8:
self._tt_config.device_type = _DEVICE_TYPE_TPU
self._tt_config.num_replicas = num_replicas
self._tt_config.num_replicas_per_host = num_replicas_per_host
self._tt_config.num_hosts = num_hosts
if self._tt_config.num_replicas is not None:
if self._tt_config.num_replicas_per_host is None:
self._tt_config.num_replicas_per_host = 8
if self._tt_config.num_hosts is None:
self._tt_config.num_hosts = (
num_replicas // self._tt_config.num_replicas_per_host +
(num_replicas % self._tt_config.num_replicas_per_host > 0))
if self._tt_config.num_replicas_per_host > 8:
# Checks for the assumption in _generate_flush_cache_op().
raise RuntimeError('num_replicas_per_host (%d) is '
'greater than 8'%self._num_replicas_per_host)
'greater than 8'%self._tt_config.num_replicas_per_host)
if self._parameters.graph_dump_path:
graph_io.write_graph(graph, self._parameters.graph_dump_path,
'graph_before_tt.pbtxt')
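The num_hosts fallback above is a ceiling division: the number of hosts needed to hold num_replicas at num_replicas_per_host replicas each. A quick illustrative check of the identity:

import math

per_host = 8
for num_replicas in range(1, 33):
  hosts = num_replicas // per_host + (num_replicas % per_host > 0)
  assert hosts == math.ceil(num_replicas / per_host)
# e.g. 16 replicas -> 2 hosts, 17 replicas -> 3 hosts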
@@ -1467,10 +1191,10 @@ class TensorTracer(object):
else:
TensorTracer._traced_graphs.add(graph)
self._device_type = _DEVICE_TYPE_CPU
self._num_replicas = 1
self._num_replicas_per_host = 1
self._num_hosts = 1
self._tt_config.device_type = _DEVICE_TYPE_CPU
self._tt_config.num_replicas = 1
self._tt_config.num_replicas_per_host = 1
self._tt_config.num_hosts = 1
self._replica_id = 0
if self._parameters.graph_dump_path:
graph_io.write_graph(graph, self._parameters.graph_dump_path,

tensorflow/python/tpu/tensor_tracer_report.py

@@ -0,0 +1,341 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""Tensor Tracer report generation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'
_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
_CURRENT_VERSION = 'use-outside-compilation'
def topological_sort(g):
"""Performs topological sort on the given graph.
Args:
g: the graph.
Returns:
A pair where the first element indicates whether a cycle is
found (True if the graph contains a cycle; False if the
topological sort succeeded) and the second element is either
the cycle of nodes found or the sorted list of nodes.
"""
def _is_loop_edge(op):
"""Returns true if the op is the end of a while-loop creating a cycle."""
return op.type in ['NextIteration']
def _in_op_degree(op):
"""Returns the number of incoming edges to the given op.
The edge calculation skips the edges that come from 'NextIteration' ops.
NextIteration creates a cycle in the graph. We break cycles by treating
this op as 'sink' and ignoring all outgoing edges from it.
Args:
op: tf.Operation
Returns:
the number of incoming edges.
"""
count = 0
for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
if not _is_loop_edge(op):
count += 1
return count
sorted_ops = []
op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
frontier.sort(key=lambda op: op.name)
while frontier:
op = frontier.pop()
# Remove the op from graph, and remove its outgoing edges.
sorted_ops.append(op)
if _is_loop_edge(op):
continue
# pylint: disable=protected-access
consumers = list(op._control_outputs)
# pylint: enable=protected-access
for out_tensor in op.outputs:
consumers += [consumer_op for consumer_op in out_tensor.consumers()]
consumers.sort(key=lambda op: op.name)
for consumer in consumers:
# For each deleted edge shift the bucket of the vertex.
op_in_degree[consumer] -= 1
if op_in_degree[consumer] == 0:
frontier.append(consumer)
if op_in_degree[consumer] < 0:
raise ValueError('consumer:%s degree mismatch'%consumer.name)
left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
if left_ops:
return (True, left_ops)
else:
assert len(g.get_operations()) == len(sorted_ops)
return (False, sorted_ops)
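A hedged usage sketch (assumes TF 1.x graph mode); note that the first return value now reads as contains_cycle, the inverse of the old TensorTracer.topological_sort convention:

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tensor_tracer_report

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name='a')
  b = tf.abs(a, name='b')
contains_cycle, order_or_cycle = tensor_tracer_report.topological_sort(g)
assert not contains_cycle  # no while loop, hence no NextIteration cycle
print([op.name for op in order_or_cycle])  # ['a', 'b']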
class TensorTracerConfig(object):
"""Tensor Tracer config object."""
def __init__(self):
self.version = _CURRENT_VERSION
self.device_type = None
self.num_replicas = None
self.num_replicas_per_host = None
self.num_hosts = None
class TensorTraceOrder(object):
"""Class that is responsible from storing the trace-id of the tensors."""
def __init__(self, graph_order, traced_tensors):
self.graph_order = graph_order
self.traced_tensors = traced_tensors
self._create_tensor_maps()
def _create_tensor_maps(self):
"""Creates tensor to cache id maps."""
self.tensorname_to_cache_idx = {}
self.cache_idx_to_tensor_idx = []
for out_tensor in self.traced_tensors:
tensor_name = out_tensor.name
if tensor_name in self.tensorname_to_cache_idx:
raise ValueError(
'Tensor name %s should not be already in '
'tensorname_to_cache_idx'%tensor_name)
if tensor_name not in self.graph_order.tensor_to_idx:
raise ValueError(
'Tensor name %s is not in the tensor_to_idx'%tensor_name)
tensor_idx = self.graph_order.tensor_to_idx[tensor_name]
cache_idx = len(self.tensorname_to_cache_idx)
self.tensorname_to_cache_idx[tensor_name] = cache_idx
self.cache_idx_to_tensor_idx.append(tensor_idx)
if len(self.tensorname_to_cache_idx) != len(
self.cache_idx_to_tensor_idx):
raise RuntimeError('len(self.tensorname_to_cache_idx) != '
'len(self.cache_idx_to_tensor_idx)')
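For illustration, here is how the two maps line up, using plain dicts in place of a real graph; b:0 and d:0 (graph indices 1 and 3) are the traced tensors:

tensor_to_idx = {'a:0': 0, 'b:0': 1, 'c:0': 2, 'd:0': 3}
traced = ['b:0', 'd:0']
tensorname_to_cache_idx = {name: i for i, name in enumerate(traced)}
cache_idx_to_tensor_idx = [tensor_to_idx[name] for name in traced]
assert tensorname_to_cache_idx == {'b:0': 0, 'd:0': 1}
assert cache_idx_to_tensor_idx == [1, 3]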
def sort_tensors_and_ops(graph):
"""Returns a wrapper that has consistent tensor and op orders."""
graph_wrapper = collections.namedtuple('GraphWrapper',
['graph', 'operations', 'op_to_idx',
'tensors', 'tensor_to_idx',
'contains_cycle',
'topological_order_or_cycle'])
operations = graph.get_operations()
op_to_idx = {op.name: index for index, op
in enumerate(operations)}
tensors = []
for op in operations:
tensors.extend(op.outputs)
tensor_to_idx = {tensor.name: index for index, tensor in
enumerate(tensors)}
contains_cycle, topological_order_or_cycle = topological_sort(graph)
return graph_wrapper(graph=graph, operations=operations, op_to_idx=op_to_idx,
tensors=tensors, tensor_to_idx=tensor_to_idx,
contains_cycle=contains_cycle,
topological_order_or_cycle=topological_order_or_cycle)
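A small usage sketch of the wrapper (assumes TF 1.x graph mode):

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tensor_tracer_report

g = tf.Graph()
with g.as_default():
  x = tf.constant(1.0, name='x')
  y = tf.abs(x, name='y')
graph_order = tensor_tracer_report.sort_tensors_and_ops(g)
print(graph_order.op_to_idx['y'])        # position of op 'y' in graph order
print(graph_order.tensor_to_idx['y:0'])  # position of its output tensor
print(graph_order.contains_cycle)        # False for this acyclic graph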
class OpenReportFile(object):
"""Context manager for writing report file."""
def __init__(self, tt_parameters):
if not tt_parameters.report_file_path:
self._report_file = None
return
try:
self._report_file = gfile.Open(tt_parameters.report_file_path, 'w')
except IOError as e:
raise e
def __enter__(self):
return self._report_file
def __exit__(self, unused_type, unused_value, unused_traceback):
if self._report_file:
self._report_file.close()
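Usage sketch for the context manager; FakeParams is a hypothetical stand-in for tensor_tracer_flags.TTParameters, and an empty report_file_path would make the manager yield None:

from tensorflow.python.tpu import tensor_tracer_report

class FakeParams(object):  # hypothetical stand-in for TTParameters
  report_file_path = '/tmp/tt_report.txt'

with tensor_tracer_report.OpenReportFile(FakeParams()) as report_file:
  if report_file:  # None when no report_file_path was set
    report_file.write('header line\n')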
class TTReportHandle(object):
"""Utility class responsible from creating a tensor tracer report."""
def __init__(self):
self.instrument_records = {}
self._report_file = None
def instrument(self, name, explanation):
self.instrument_records[name] = explanation
def instrument_op(self, op, explanation):
self.instrument(op.name, explanation)
def instrument_tensor(self, tensor, explanation):
self.instrument(tensor.name, explanation)
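The records are keyed by op or tensor name; a minimal sketch:

from tensorflow.python.tpu import tensor_tracer_report

handle = tensor_tracer_report.TTReportHandle()
# instrument_op(op, ...) and instrument_tensor(tensor, ...) both reduce to
# this name-keyed form.
handle.instrument('while/add', '12 not-traced-special-whileloop-op')
assert handle.instrument_records['while/add'].endswith('whileloop-op')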
def create_report(self, tt_config, tt_parameters,
tensor_trace_order, tensor_trace_points):
"""Creates a report file and writes the trace information."""
with OpenReportFile(tt_parameters) as self._report_file:
self._write_config_section(tt_config, tt_parameters)
self._write_op_list_section(tensor_trace_order.graph_order)
self._write_tensor_list_section(tensor_trace_order.graph_order)
self._write_trace_points(tensor_trace_points)
self._write_cache_index_map_section(tensor_trace_order)
self._write_reason_section()
self._write_graph_section(tensor_trace_order.graph_order)
def _write_trace_points(self, tensor_trace_points):
"""Writes the list of checkpoints."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT))
for (tensor, checkpoint_name) in tensor_trace_points:
self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT))
def _write_report(self, content):
"""Writes the given content to the report."""
line = '%s %s'%(_TRACER_LOG_PREFIX, content)
if self._report_file:
self._report_file.write(line)
else:
logging.info(line)
def _write_config_section(self, tt_config, tt_parameters):
"""Writes the config section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, tt_config.version))
self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, tt_config.device_type))
self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
tt_parameters.trace_mode))
self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
tt_parameters.submode))
if tt_parameters.included_cores:
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
len(tt_parameters.included_cores)))
else:
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
tt_config.num_replicas))
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
tt_config.num_replicas_per_host))
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, tt_config.num_hosts))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))
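Rendered, a config section comes out roughly like this (replica counts and the trace-mode/submode strings are illustrative):

 [>>>TT>>>] !!!!!!! section-begin: configuration
 [>>>TT>>>] version: use-outside-compilation
 [>>>TT>>>] device: tpu
 [>>>TT>>>] trace-mode: norm
 [>>>TT>>>] submode: detailed
 [>>>TT>>>] num-replicas: 8
 [>>>TT>>>] num-replicas-per-host: 8
 [>>>TT>>>] num-hosts: 1
 [>>>TT>>>] !!!!!!! section-end: configuration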
def _write_reason_section(self):
"""Writes the reason section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
for key in sorted(self.instrument_records):
self._write_report('"%s" %s\n'%(key, self.instrument_records[key]))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))
def _write_op_list_section(self, graph_order):
"""Writes the Op-list section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS,
len(graph_order.operations)))
for i in range(0, len(graph_order.operations)):
op = graph_order.operations[i]
line = '%d "%s" %s'%(i, op.name, op.type)
for out_tensor in op.outputs:
if out_tensor.name not in graph_order.tensor_to_idx:
raise ValueError(
'out_tensor %s is not in tensor_to_idx'%out_tensor.name)
line += ' %d'%graph_order.tensor_to_idx[out_tensor.name]
line += '\n'
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))
def _write_tensor_list_section(self, graph_order):
"""Writes the tensor-list section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_SECTION_NAME_TENSOR_LIST))
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS,
len(graph_order.tensors)))
for i in range(0, len(graph_order.tensors)):
tensor = graph_order.tensors[i]
line = '%d "%s"'%(i, tensor.name)
consumers = tensor.consumers()
consumers.sort(key=lambda op: op.name)
for consumer_op in consumers:
if consumer_op.name not in graph_order.op_to_idx:
raise ValueError(
'consumer_op %s is not in op_to_idx'%consumer_op.name)
line += ' %d'%graph_order.op_to_idx[consumer_op.name]
line += '\n'
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_SECTION_NAME_TENSOR_LIST))
def _write_cache_index_map_section(self, tensor_trace_order):
"""Writes the mapping from cache index to tensor index to the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_SECTION_NAME_CACHE_INDEX_MAP))
self._write_report('%s %d\n'%(
_FIELD_NAME_NUM_CACHE_INDICES,
len(tensor_trace_order.cache_idx_to_tensor_idx)))
for cache_idx in range(0, len(tensor_trace_order.cache_idx_to_tensor_idx)):
tensor_idx = tensor_trace_order.cache_idx_to_tensor_idx[cache_idx]
line = '%d %d\n'%(cache_idx, tensor_idx)
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_SECTION_NAME_CACHE_INDEX_MAP))
def _write_graph_section(self, graph_order):
"""Writes the graph section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
not graph_order.contains_cycle))
l = list(graph_order.topological_order_or_cycle)
for i in range(0, len(l)):
self._write_report('%d "%s"\n'%(i, l[i].name))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
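Putting the pieces together, the report path mirrors _determine_trace_and_create_report in tensor_tracer.py; a hedged end-to-end sketch (FakeParams is hypothetical, and an empty report_file_path routes every section to logging.info):

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tensor_tracer_report

class FakeParams(object):  # hypothetical stand-in for TTParameters
  report_file_path = ''    # empty -> report lines go to logging.info
  trace_mode = 'norm'      # illustrative mode strings
  submode = 'detailed'
  included_cores = None

g = tf.Graph()
with g.as_default():
  x = tf.constant(1.0, name='x')
  y = tf.abs(x, name='y')

graph_order = tensor_tracer_report.sort_tensors_and_ops(g)
traced = [t for op in graph_order.operations for t in op.outputs]
trace_order = tensor_tracer_report.TensorTraceOrder(graph_order, traced)
handle = tensor_tracer_report.TTReportHandle()
for i, op in enumerate(graph_order.operations):
  handle.instrument_op(op, '%d traced' % i)
handle.create_report(tensor_tracer_report.TensorTracerConfig(), FakeParams(),
                     trace_order, tensor_trace_points=[])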