Tensor tracer refactor.
PiperOrigin-RevId: 250316061
This commit is contained in:
parent c47ec54211
commit 53027266f0
@@ -155,6 +155,7 @@ py_library(
        "session_support.py",
        "tensor_tracer.py",
        "tensor_tracer_flags.py",
        "tensor_tracer_report.py",
        "topology.py",
        "tpu.py",
        "tpu_feed.py",

@@ -40,13 +40,14 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tensor_tracer_flags
from tensorflow.python.tpu import tensor_tracer_report
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.ops import tpu_ops

_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
_TRACE_MODE_PART_TENSOR_SIZE = 3

_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
@@ -62,28 +63,9 @@ _REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'
_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'

_OUTPUT_STREAM_ESCAPE = 'file://'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
@@ -191,6 +173,7 @@ def _create_tensor_values_cache(graph, num_tensors):
        use_resource=True,
        collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])


class TensorTracer(object):
  """A software construct for tracing tensor values in a TF graph on TPU.

@@ -298,142 +281,28 @@ class TensorTracer(object):

    return '%d %s'%(op_idx, details)

  @staticmethod
  def topological_sort(g):
    """Performs topological sort on the given graph.

    Args:
       g: the graph.

    Returns:
       A pair where the first element indicates if the topological
       sort succeeded (True if there is no cycle found; False if a
       cycle is found) and the second element is either the sorted
       list of nodes or the cycle of nodes found.
    """
    def _is_loop_edge(op):
      """Returns true if the op is the end of a while-loop creating a cycle."""
      return op.type in ['NextIteration']

    def _in_op_degree(op):
      """Returns the number of incoming edges to the given op.

      The edge calculation skips the edges that come from 'NextIteration' ops.
      NextIteration creates a cycle in the graph. We break cycles by treating
      this op as 'sink' and ignoring all outgoing edges from it.
      Args:
        op: tf.Operation
      Returns:
        the number of incoming edges.
      """
      count = 0
      for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
        if not _is_loop_edge(op):
          count += 1
      return count

    sorted_ops = []
    op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}

    frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
    frontier.sort(key=lambda op: op.name)
    while frontier:
      op = frontier.pop()
      # Remove the op from graph, and remove its outgoing edges.
      sorted_ops.append(op)
      if _is_loop_edge(op):
        continue
      # pylint: disable=protected-access
      consumers = list(op._control_outputs)
      # pylint: enable=protected-access
      for out_tensor in op.outputs:
        consumers += [consumer_op for consumer_op in out_tensor.consumers()]
      consumers.sort(key=lambda op: op.name)
      for consumer in consumers:
        # For each deleted edge shift the bucket of the vertex.
        op_in_degree[consumer] -= 1
        if op_in_degree[consumer] == 0:
          frontier.append(consumer)
        if op_in_degree[consumer] < 0:
          raise ValueError('consumer:%s degree mismatch'%consumer.name)

    left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
    if left_ops:
      return (False, left_ops)
    else:
      assert len(g.get_operations()) == len(sorted_ops)
      return (True, sorted_ops)

  @staticmethod
  def _make_op_and_tensor_maps(op_list):
    """Creates various maps and lists from op_list.

    Args:
       op_list: a list of Ops

    Returns:
       opname_idx_map: a map from Op's name to its index in op_list.
       tensor_list: a list of output tensors of the Ops in op_list.
       tensorname_idx_map: a map from output tensor name to its index
                           in tensor_list.
    """

    opname_idx_map = {}
    tensor_list = []
    tensorname_idx_map = {}
    for op_id, op in enumerate(op_list):
      if op.name in opname_idx_map:
        raise ValueError('Duplicated Op name: %s'%op.name)
      opname_idx_map[op.name] = op_id
      for output_tensor in op.outputs:
        if output_tensor.name not in tensorname_idx_map:
          tensor_list.append(output_tensor)
          tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
    return (opname_idx_map, tensor_list, tensorname_idx_map)

  def __init__(self):
    """Initializes a TensorTracer.

    Sets the various member fields from the flags (if given) or the defaults.
    """
    self._parameters = tensor_tracer_flags.TTParameters()
    self._set_report_file()
    self._version = 'use-outside-compilation'
    self._device_type = None
    self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
    self._instrument_records = {}
    self._num_replicas = None
    self._num_replicas_per_host = None
    self._num_hosts = None
    self._replica_id = None
    self._tt_config = tensor_tracer_report.TensorTracerConfig()
    self._parameters = tensor_tracer_flags.TTParameters()
    self._included_op_full_names = set()

  def _add_replica_id_to_graph(self):
    """Adds nodes for computing the replica ID to the graph."""

    if self._num_replicas:
    if self._tt_config.num_replicas:
      with ops.control_dependencies(None):
        # Uses None as dependency to run outside of TPU graph rewrites.
        self._replica_id = tpu_ops.tpu_replicated_input(
            list(range(self._num_replicas)),
            list(range(self._tt_config.num_replicas)),
            name='tt_replica_id')
    else:
      self._replica_id = 'unknown'

  def _set_report_file(self):
    """Sets the path of the output report file."""
    if not self._parameters.report_file_path:
      self._report_file = None
      return
    try:
      self._report_file = gfile.Open(self._parameters.report_file_path, 'w')
    except IOError as e:
      raise e

  def _close_report_file(self):
    if self._report_file:
      self._report_file.close()

  def _inside_op_range(self, idx):
    """Return True if the given index is inside the selected range."""

@@ -519,107 +388,6 @@ class TensorTracer(object):
    indices = constant_op.constant([cache_idx])
    return state_ops.scatter_update(cache, indices, updates).op
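The cache update above is a plain scatter-update into a per-replica cache
variable. A minimal self-contained sketch of the same pattern (a hedged
illustration with hypothetical names, assuming TF1-style graph mode; the
real cache is created by _create_tensor_values_cache with one slot per
traced tensor, initialized to _COMPACT_TRACE_ENTRY_INIT_VALUE):

import tensorflow.compat.v1 as tf

cache = tf.Variable([-1.0, -1.0, -1.0], name='tt_cache_demo')
cache_idx = 1                          # slot assigned to the traced tensor
updates = tf.constant([42.0])          # traced value, shape [1]
indices = tf.constant([cache_idx])
update_op = tf.scatter_update(cache, indices, updates).op

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op)
  print(sess.run(cache))               # [-1. 42. -1.]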

  def _write_report(self, content):
    """Writes the given content to the report."""

    line = '%s %s'%(_TRACER_LOG_PREFIX, content)
    if self._report_file:
      self._report_file.write(line)
    else:
      logging.info(line)

  def _write_config_section(self):
    """Writes the config section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
    self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version))
    self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type))
    self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
                                  self._parameters.trace_mode))
    self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
                                  self._parameters.submode))
    if self._parameters.included_cores:
      self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
                                    len(self._parameters.included_cores)))
    else:
      self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
                                    self._num_replicas))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
                                  self._num_replicas_per_host))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))

  def _write_reason_section(self):
    """Writes the reason section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
    for key in sorted(self._instrument_records):
      self._write_report('"%s" %s\n'%(key, self._instrument_records[key]))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))

  def _write_op_list_section(self, op_list):
    """Writes the Op-list section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list)))
    for i in range(0, len(op_list)):
      op = op_list[i]
      line = '%d "%s" %s'%(i, op.name, op.type)
      for out_tensor in op.outputs:
        if out_tensor.name not in self._tensorname_idx_map:
          raise ValueError(
              'out_tensor %s is not in tensorname_idx_map'%out_tensor.name)
        line += ' %d'%self._tensorname_idx_map[out_tensor.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))

  def _write_tensor_list_section(self, tensor_list, opname_idx_map):
    """Writes the tensor-list section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list)))
    for i in range(0, len(tensor_list)):
      tensor = tensor_list[i]
      line = '%d "%s"'%(i, tensor.name)
      consumers = tensor.consumers()
      consumers.sort(key=lambda op: op.name)
      for consumer_op in consumers:
        if consumer_op.name not in opname_idx_map:
          raise ValueError(
              'consumer_op %s is not in opname_idx_map'%consumer_op.name)
        line += ' %d'%opname_idx_map[consumer_op.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_LIST))

  def _write_cache_index_map_section(self):
    """Writes the mapping from cache index to tensor index to the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_CACHE_INDEX_MAP))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES,
                                  len(self._cache_idx_to_tensor_idx)))
    for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)):
      tensor_idx = self._cache_idx_to_tensor_idx[cache_idx]
      line = '%d %d\n'%(cache_idx, tensor_idx)
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_CACHE_INDEX_MAP))

  def _write_graph_section(self, succeed, sorted_or_cycle):
    """Writes the graph section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
    self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
                                  succeed))
    l = list(sorted_or_cycle)
    for i in range(0, len(l)):
      self._write_report('%d "%s"\n'%(i, l[i].name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))

  def _preprocess_traced_tensor(self, tensor):
    """Computes NaN/Norm/Max on TPUs before sending to CPU.

@@ -699,12 +467,12 @@ class TensorTracer(object):
        'Tensor trace fun for %s is not yet implemented'
        % self._parameters.trace_mode)

  def _make_tensor_trace_fun(self, tensor_name):
  def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
    """Makes the tensor tracing function called by outside compilation.

    Args:
      tensor_name: name of the tensor being traced.

      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
    Returns:
      A function to be passed as the first argument to outside compilation.

@@ -730,7 +498,7 @@ class TensorTracer(object):
      """

      if self._parameters.is_brief_mode():
        if tensor_name not in self._tensorname_idx_map:
        if tensor_name not in tensor_trace_order.tensorname_idx_map:
          raise ValueError(
              'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
        msg = '%d'%self._tensorname_idx_map[tensor_name]
@@ -751,7 +519,7 @@ class TensorTracer(object):
    def _show_part_tensor(tensor):
      """Trace function for printing part of the tensor."""

      return _print_tensor(tensor_name, self._part_tensor_size,
      return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
                           tensor, tensor)

    def _show_full_tensor(tensor):
@@ -808,47 +576,44 @@ class TensorTracer(object):
    raise RuntimeError('Tensor trace fun for %s is not yet implemented'
                       %self._parameters.trace_mode)

  def _skip_op(self, op_id, op, user_included, user_excluded,
               in_exec_path=True):
  def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
    """Returns True if we should not trace Op."""

    if TensorTracer.while_loop_op(op):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_WHILELOOP_OP)
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
      return True
    if TensorTracer.unsafe_op(op):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_UNSAFE_OP)
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
      return True
    if TensorTracer.device_mismatch(self._device_type, op):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_DEVICE_MISMATCH)
    if TensorTracer.device_mismatch(self._tt_config.device_type, op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
      return True
    if not in_exec_path:
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_NOT_EXECUTED)
    if op not in ops_in_exec_path:
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
      return True

    if not self._inside_op_range(op_id):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_OUTSIDE_OP_RANGE)
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
      return True
    if self._less_interesting_op(op):
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_LESS_INTERESTING_OP)
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
      return True
    if user_included:
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_USER_INCLUDED)
    if self._is_user_included_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
      return False
    if user_excluded:
      self._instrument_records[op.name] = TensorTracer.reason(
          op_id, _REASON_USER_EXCLUDED)
    if self._is_user_excluded_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
      return True
    return False
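A minimal sketch of the new instrumentation pattern above: skip decisions
are now recorded on a TTReportHandle rather than in the tracer's own
_instrument_records dict (the _FakeOp stand-in and the reason string are
hypothetical, for illustration only):

from tensorflow.python.tpu import tensor_tracer_report

class _FakeOp(object):
  """Stand-in for a tf.Operation; only .name is consulted here."""
  name = 'while/NextIteration'

report_handler = tensor_tracer_report.TTReportHandle()
# Same call shape as in _skip_op above.
report_handler.instrument_op(_FakeOp(), '12 not-traced-special-whileloop-op')
print(report_handler.instrument_records)
# {'while/NextIteration': '12 not-traced-special-whileloop-op'}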

  def _skip_tensor(self, op_id, out_tensor, user_included,
                   user_excluded):
  def _skip_tensor(self, op_id, out_tensor, report_handler):
    """Returns True if we should not trace out_tensor."""

    # Skips a tensor if the tensor has a non-numeric type.
@@ -858,22 +623,23 @@ class TensorTracer(object):
    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
                                    dtypes.string])
    if out_tensor.dtype in non_numeric_tensor_types:
      self._instrument_records[out_tensor.name] = TensorTracer.reason(
          op_id, _REASON_NON_NUMERIC_TENSOR)

      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
      return True
    # Skip a tensor if it feeds a special while loop op.
    if [consumer for consumer in out_tensor.consumers() if
        TensorTracer.while_loop_op(consumer)]:
      self._instrument_records[out_tensor.name] = TensorTracer.reason(
          op_id, _REASON_FEEDS_WHILELOOP_OP)
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
      return True
    if user_included:
      self._instrument_records[out_tensor.name] = TensorTracer.reason(
          op_id, _REASON_USER_INCLUDED)
    if self._is_user_included_op(out_tensor.op):
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
      return False
    if user_excluded:
      self._instrument_records[out_tensor.name] = TensorTracer.reason(
          op_id, _REASON_USER_EXCLUDED)
    if self._is_user_excluded_op(out_tensor.op):
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
      return True
    if not out_tensor.get_shape().is_fully_defined():
      # If trace mode is nan-inf, norm or max, then the tensor will be reduced
@@ -883,33 +649,33 @@ class TensorTracer(object):
          tensor_tracer_flags.TRACE_MODE_NORM,
          tensor_tracer_flags.TRACE_MODE_MAX_ABS
      ]:
        self._instrument_records[out_tensor.name] = TensorTracer.reason(
            op_id, _REASON_TENSOR_GET_TRACED)
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
        return False
      else:
        self._instrument_records[out_tensor.name] = TensorTracer.reason(
            op_id, _REASON_DYNAMIC_SHAPE)
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
        return True
    rank = len(out_tensor.shape)
    if rank < 1:
      # scalar
      if self._parameters.trace_scalar_ops:
        if TensorTracer.unsafe_scalar_trace(out_tensor.op):
          self._instrument_records[out_tensor.name] = TensorTracer.reason(
              op_id, _REASON_UNSAFE_SCALAR)
          report_handler.instrument_tensor(
              out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
          return True
        else:
          self._instrument_records[out_tensor.name] = TensorTracer.reason(
              op_id, _REASON_SCALAR_GET_TRACED)
          report_handler.instrument_tensor(
              out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
          return False
      else:
        self._instrument_records[out_tensor.name] = TensorTracer.reason(
            op_id, _REASON_SKIP_SCALAR)
        report_handler.instrument_tensor(
            out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
        return True
    else:
      # tensor
      self._instrument_records[out_tensor.name] = TensorTracer.reason(
          op_id, _REASON_TENSOR_GET_TRACED)
      report_handler.instrument_tensor(
          out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
      return False

  def _filter_execution_path_operations(self, operations, fetches):
@@ -951,42 +717,25 @@ class TensorTracer(object):
          traverse_stack.append(input_op)
    return execution_path_operations
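Only the tail of _filter_execution_path_operations appears in this hunk;
the method walks backwards from the fetch ops over data and control edges.
A hedged sketch of that traversal style (a reconstruction for illustration,
not the exact implementation):

def _ops_feeding(fetch_ops):
  """Returns the set of ops reachable backwards from fetch_ops."""
  execution_path_operations = set(fetch_ops)
  traverse_stack = list(fetch_ops)
  while traverse_stack:
    op = traverse_stack.pop()
    for input_op in [t.op for t in op.inputs] + list(op.control_inputs):
      if input_op not in execution_path_operations:
        execution_path_operations.add(input_op)
        traverse_stack.append(input_op)
  return execution_path_operations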

  def _determine_traced_tensors(self, graph, ops_in_exec_path):
    """Determines the tensors that will be traced."""
  def _determine_and_instrument_traced_tensors(self, graph_order,
                                               ops_in_exec_path,
                                               tensor_trace_points,
                                               report_handler):
    """Determines the tensors to trace and instruments the trace details."""

    self._traced_tensorname_to_cache_idx_map = {}
    self._cache_idx_to_tensor_idx = []
    operations = graph.get_operations()
    checkpoint_operations = self._get_checkpoints(graph)
    for op_id, op in enumerate(operations):
      if checkpoint_operations and op.name not in checkpoint_operations:
    traced_tensors = []
    checkpoint_operations = set([tensor.op
                                 for (tensor, _) in tensor_trace_points])
    for op_id, op in enumerate(graph_order.operations):
      if checkpoint_operations and op not in checkpoint_operations:
        continue
      user_included = self._is_user_included_op(op)
      user_excluded = self._is_user_excluded_op(op)
      in_exec_path = op in ops_in_exec_path
      if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
      if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
        continue
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        if self._skip_tensor(op_id, out_tensor, user_included,
                             user_excluded):
          continue
        tensor_name = out_tensor.name
        if tensor_name in self._traced_tensorname_to_cache_idx_map:
          raise ValueError(
              'Tensor name %s should not be already in '
              'traced_tensorname_to_cache_idx_map'%tensor_name)
        if tensor_name not in self._tensorname_idx_map:
          raise ValueError(
              'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
        tensor_idx = self._tensorname_idx_map[tensor_name]
        cache_idx = len(self._traced_tensorname_to_cache_idx_map)
        self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx
        self._cache_idx_to_tensor_idx.append(tensor_idx)
        if len(self._traced_tensorname_to_cache_idx_map) != len(
            self._cache_idx_to_tensor_idx):
          raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != '
                             'len(self._cache_idx_to_tensor_idx)')
        if not self._skip_tensor(op_id, out_tensor, report_handler):
          traced_tensors.append(out_tensor)
    return traced_tensors

  def _check_trace_files(self):
    """Checks if any requirements for trace files are satisfied."""
@@ -995,7 +744,7 @@ class TensorTracer(object):
      # traces will be written to stderr. No need to check trace files.
      return
    if _trace_files_need_precreated(self._parameters.trace_dir):
      for replica_id in range(0, self._num_replicas):
      for replica_id in range(0, self._tt_config.num_replicas):
        trace_file_path = os.path.join(
            self._parameters.trace_dir,
            _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id
@@ -1009,56 +758,25 @@ class TensorTracer(object):
        if not gfile.Exists(self._parameters.trace_dir):
          raise RuntimeError('Failed to create %s'%self._parameters.trace_dir)

  def _pre_tracing(self, graph, fetches):
  def _determine_trace_and_create_report(self, graph, ops_in_exec_path):
    """Work needs to be done prior to TPU or CPU tracing."""

    self._check_trace_files()
    operations = graph.get_operations()
    (opname_idx_map, tensor_list, self._tensorname_idx_map) = (
        TensorTracer._make_op_and_tensor_maps(operations))
    self._write_config_section()
    self._write_op_list_section(operations)
    self._write_tensor_list_section(tensor_list, opname_idx_map)
    # Filter out the operations that won't be executed.
    # if fetches=None, then ops_in_exec_path = set(operations)
    ops_in_exec_path = self._filter_execution_path_operations(operations,
                                                              fetches)
    self._determine_traced_tensors(graph, ops_in_exec_path)
    self._write_cache_index_map_section()
    # Does the topological sort before adding any nodes to the graph.
    (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)

    graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
    tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)

    report_handler = tensor_tracer_report.TTReportHandle()
    traced_tensors = self._determine_and_instrument_traced_tensors(
        graph_order, ops_in_exec_path, tensor_trace_points, report_handler)

    tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
                                                               traced_tensors)
    if self._use_tensor_values_cache():
      _create_tensor_values_cache(graph,
                                  len(self._cache_idx_to_tensor_idx))
    return (ops_in_exec_path, succeed, sorted_or_cycle)

  def _post_tracing(self, succeed, sorted_or_cycle):
    """Work needs to be done after TPU or CPU tracing."""

    self._write_reason_section()
    self._write_graph_section(succeed, sorted_or_cycle)
    self._close_report_file()

  def _get_checkpoints(self, graph):
    """Returns the set of Ops that produce the tensors traced with the API.

    Args:
      graph: the graph of Ops.

    Returns:
      A set of operation names which should be traced.
    """

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _TENSOR_TRACER_CHECKPOINT))
    checkpoint_operations = set()
    tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION)
    for (tensor, checkpoint_name) in tensor_tracer_variables:
      self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
      checkpoint_operations.add(tensor.op.name)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _TENSOR_TRACER_CHECKPOINT))
    return checkpoint_operations
      _create_tensor_values_cache(graph, len(traced_tensors))
    report_handler.create_report(self._tt_config, self._parameters,
                                 tensor_trace_order, tensor_trace_points)
    return tensor_trace_order
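Net effect of the refactor: the report sections that _pre_tracing and
_post_tracing used to write piecemeal are now emitted in one shot by
TTReportHandle.create_report. A hedged sketch of driving the new report
path by hand (assumes a tf.Graph bound to `graph`; with no
report_file_path set, output falls through to logging.info):

from tensorflow.python.tpu import tensor_tracer_flags
from tensorflow.python.tpu import tensor_tracer_report

graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
                                                    traced_tensors=[])
report_handler = tensor_tracer_report.TTReportHandle()
report_handler.create_report(tensor_tracer_report.TensorTracerConfig(),
                             tensor_tracer_flags.TTParameters(),
                             trace_order, tensor_trace_points=[])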

  def _generate_flush_cache_op(self, graph, start_replica, on_tpu):
    """Generates an Op that will flush the cache to file.
@@ -1085,9 +803,9 @@ class TensorTracer(object):
          else:
            replica_id_str = '%d'%replica_id
          if self._parameters.trace_dir:
            output_path = os.path.join(self._parameters.trace_dir,
                                       _COMPACT_TRACE_FILE_PREFIX) \
                                       + replica_id_str
            output_path = (os.path.join(self._parameters.trace_dir,
                                        _COMPACT_TRACE_FILE_PREFIX)
                           + replica_id_str)
            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
          else:
            output_stream = sys.stderr
@@ -1153,7 +871,7 @@ class TensorTracer(object):
    with ops.control_dependencies(op_fetches +
                                  [tensor.op for tensor in tensor_fetches]):
      flush_cache_op_list = []
      for host in range(self._num_hosts):
      for host in range(self._tt_config.num_hosts):
        start_replica = host * 8
        flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu)
        flush_cache_op_list.append(flush_op)
@@ -1233,8 +951,8 @@ class TensorTracer(object):
                       on_tpu=True):
    """Common tracing function for both CPU and TPUs.

    The caller function should set _device_type, _num_replicas,
    _num_replicas_per_host, _num_hosts and _replica_id before calling
    The caller function should set device_type, num_replicas,
    num_replicas_per_host, num_hosts and replica_id before calling
    _trace_execution.


@@ -1266,15 +984,19 @@ class TensorTracer(object):
        return math_ops.cast(tensor, dtypes.float32)
      return tensor

    TensorTracer.check_device_type(self._device_type)
    TensorTracer.check_device_type(self._tt_config.device_type)
    # Check in_tensor_fetches, and op_fetches and convert them to lists.
    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
    op_fetches = self._process_op_fetches(op_fetches)
    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

    # Filter the set of ops that will be executed, and topological sort.
    (exec_op_set, succeed, sorted_or_cycle) = self._pre_tracing(graph,
                                                                all_fetches)
    # Filter out the operations that won't be executed.
    # if fetches=None, then ops_in_exec_path = set(operations)
    exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
                                                         all_fetches)
    # Write report file, and determine the traced tensors.
    tensor_trace_order = self._determine_trace_and_create_report(
        graph, exec_op_set)

    tensor_fetch_set = set(processed_t_fetches)
    tracing_ops = []
@@ -1290,7 +1012,7 @@ class TensorTracer(object):
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        tensor_name = out_tensor.name
        if tensor_name not in self._traced_tensorname_to_cache_idx_map:
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          continue
        # Create the list of consumers before calling _preprocess_traced_tensor.
        # Otherwise, adding control input below, will introduce a cycle in the
@@ -1316,7 +1038,7 @@ class TensorTracer(object):
          processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)

        if self._use_tensor_values_cache():
          cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name]
          cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
          trace_op = self._save_tensor_value_to_cache_op(graph,
                                                         cache_idx,
                                                         processed_out_tensor)
@@ -1324,7 +1046,8 @@ class TensorTracer(object):

          def tpu_wrap_trace_fn(tensor, out_tensor_name):
            """Wraps the trace_fn with outside compilation if on TPUs."""
            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name)
            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
                                                          tensor_trace_order)
            if on_tpu:
              return tpu.outside_compilation(tensor_trace_fn, tensor)
            else:
@@ -1374,7 +1097,6 @@ class TensorTracer(object):
                                                            processed_t_fetches,
                                                            op_fetches,
                                                            on_tpu=on_tpu)
    self._post_tracing(succeed, sorted_or_cycle)
    # processed_t_fetches is a list at this point. Convert it to the same
    # format as given in tensor_fetches.
    return self._convert_fetches_to_input_format(tensor_fetches,
@@ -1414,21 +1136,23 @@ class TensorTracer(object):
      return tensor_fetches
    else:
      TensorTracer._traced_graphs.add(graph)
    self._device_type = _DEVICE_TYPE_TPU
    self._num_replicas = num_replicas
    self._num_replicas_per_host = num_replicas_per_host
    self._num_hosts = num_hosts
    if self._num_replicas is not None:
      if self._num_replicas_per_host is None:
        self._num_replicas_per_host = 8
      if self._num_hosts is None:
        self._num_hosts = num_replicas // self._num_replicas_per_host + \
            (num_replicas % self._num_replicas_per_host > 0)

    if self._num_replicas_per_host > 8:
    self._tt_config.device_type = _DEVICE_TYPE_TPU
    self._tt_config.num_replicas = num_replicas
    self._tt_config.num_replicas_per_host = num_replicas_per_host
    self._tt_config.num_hosts = num_hosts
    if self._tt_config.num_replicas is not None:
      if self._tt_config.num_replicas_per_host is None:
        self._tt_config.num_replicas_per_host = 8
      if self._tt_config.num_hosts is None:
        self._tt_config.num_hosts = (
            num_replicas // self._tt_config.num_replicas_per_host +
            (num_replicas % self._tt_config.num_replicas_per_host > 0))
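The num_hosts expression above is integer ceiling division spelled with a
boolean remainder term (True counts as 1), e.g.:

num_replicas, num_replicas_per_host = 20, 8
num_hosts = (num_replicas // num_replicas_per_host +
             (num_replicas % num_replicas_per_host > 0))
print(num_hosts)  # 3 hosts: replicas split as 8 + 8 + 4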

    if self._tt_config.num_replicas_per_host > 8:
      # Checks for the assumption in _generate_flush_cache_op().
      raise RuntimeError('num_replicas_per_host (%d) is '
                         'greater than 8'%self._num_replicas_per_host)
                         'greater than 8'%self._tt_config.num_replicas_per_host)
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,
                           'graph_before_tt.pbtxt')
@@ -1467,10 +1191,10 @@ class TensorTracer(object):
    else:
      TensorTracer._traced_graphs.add(graph)

    self._device_type = _DEVICE_TYPE_CPU
    self._num_replicas = 1
    self._num_replicas_per_host = 1
    self._num_hosts = 1
    self._tt_config.device_type = _DEVICE_TYPE_CPU
    self._tt_config.num_replicas = 1
    self._tt_config.num_replicas_per_host = 1
    self._tt_config.num_hosts = 1
    self._replica_id = 0
    if self._parameters.graph_dump_path:
      graph_io.write_graph(graph, self._parameters.graph_dump_path,

tensorflow/python/tpu/tensor_tracer_report.py (new file, 341 lines)
@@ -0,0 +1,341 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""Tensor Tracer report generation utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging

_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'

_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'

_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'

_CURRENT_VERSION = 'use-outside-compilation'


def topological_sort(g):
  """Performs topological sort on the given graph.

  Args:
     g: the graph.

  Returns:
     A pair where the first element indicates whether the graph
     contains a cycle (True if a cycle is found; False otherwise)
     and the second element is either the cycle of nodes found or
     the topologically sorted list of nodes.
  """
  def _is_loop_edge(op):
    """Returns true if the op is the end of a while-loop creating a cycle."""
    return op.type in ['NextIteration']

  def _in_op_degree(op):
    """Returns the number of incoming edges to the given op.

    The edge calculation skips the edges that come from 'NextIteration' ops.
    NextIteration creates a cycle in the graph. We break cycles by treating
    this op as 'sink' and ignoring all outgoing edges from it.
    Args:
      op: tf.Operation
    Returns:
      the number of incoming edges.
    """
    count = 0
    for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
      if not _is_loop_edge(op):
        count += 1
    return count

  sorted_ops = []
  op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}

  frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
  frontier.sort(key=lambda op: op.name)
  while frontier:
    op = frontier.pop()
    # Remove the op from graph, and remove its outgoing edges.
    sorted_ops.append(op)
    if _is_loop_edge(op):
      continue
    # pylint: disable=protected-access
    consumers = list(op._control_outputs)
    # pylint: enable=protected-access
    for out_tensor in op.outputs:
      consumers += [consumer_op for consumer_op in out_tensor.consumers()]
    consumers.sort(key=lambda op: op.name)
    for consumer in consumers:
      # For each deleted edge shift the bucket of the vertex.
      op_in_degree[consumer] -= 1
      if op_in_degree[consumer] == 0:
        frontier.append(consumer)
      if op_in_degree[consumer] < 0:
        raise ValueError('consumer:%s degree mismatch'%consumer.name)

  left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
  if left_ops:
    return (True, left_ops)
  else:
    assert len(g.get_operations()) == len(sorted_ops)
    return (False, sorted_ops)
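Note the return convention here is inverted relative to the old
TensorTracer.topological_sort above: the first element now means
"contains a cycle" (see how sort_tensors_and_ops unpacks it below).
A small usage sketch, assuming a TF1-style graph:

import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tensor_tracer_report

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name='a')
  b = tf.add(a, a, name='b')

contains_cycle, order = tensor_tracer_report.topological_sort(g)
assert not contains_cycle          # no while-loop, so the graph is acyclic
print([op.name for op in order])   # ['a', 'b'] in dependency order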


class TensorTracerConfig(object):
  """Tensor Tracer config object."""

  def __init__(self):
    self.version = _CURRENT_VERSION
    self.device_type = None
    self.num_replicas = None
    self.num_replicas_per_host = None
    self.num_hosts = None


class TensorTraceOrder(object):
  """Class that is responsible for storing the trace-id of the tensors."""

  def __init__(self, graph_order, traced_tensors):
    self.graph_order = graph_order
    self.traced_tensors = traced_tensors
    self._create_tensor_maps()

  def _create_tensor_maps(self):
    """Creates tensor to cache id maps."""
    self.tensorname_to_cache_idx = {}
    self.cache_idx_to_tensor_idx = []
    for out_tensor in self.traced_tensors:
      tensor_name = out_tensor.name
      if tensor_name in self.tensorname_to_cache_idx:
        raise ValueError(
            'Tensor name %s should not be already in '
            'tensorname_to_cache_idx'%tensor_name)
      if tensor_name not in self.graph_order.tensor_to_idx:
        raise ValueError(
            'Tensor name %s is not in the tensor_to_idx'%tensor_name)
      tensor_idx = self.graph_order.tensor_to_idx[tensor_name]
      cache_idx = len(self.tensorname_to_cache_idx)
      self.tensorname_to_cache_idx[tensor_name] = cache_idx
      self.cache_idx_to_tensor_idx.append(tensor_idx)
      if len(self.tensorname_to_cache_idx) != len(
          self.cache_idx_to_tensor_idx):
        raise RuntimeError('len(self.tensorname_to_cache_idx) != '
                           'len(self.cache_idx_to_tensor_idx)')


def sort_tensors_and_ops(graph):
  """Returns a wrapper that has consistent tensor and op orders."""
  graph_wrapper = collections.namedtuple('GraphWrapper',
                                         ['graph', 'operations', 'op_to_idx',
                                          'tensors', 'tensor_to_idx',
                                          'contains_cycle',
                                          'topological_order_or_cycle'])
  operations = graph.get_operations()
  op_to_idx = {op.name: index for index, op
               in enumerate(operations)}
  tensors = []
  for op in operations:
    tensors.extend(op.outputs)
  tensor_to_idx = {tensor.name: index for index, tensor in
                   enumerate(tensors)}
  contains_cycle, topological_order_or_cycle = topological_sort(graph)
  return graph_wrapper(graph=graph, operations=operations, op_to_idx=op_to_idx,
                       tensors=tensors, tensor_to_idx=tensor_to_idx,
                       contains_cycle=contains_cycle,
                       topological_order_or_cycle=topological_order_or_cycle)
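Continuing the sketch above: the wrapper pins stable op/tensor indices,
which TensorTraceOrder then maps onto cache slots (hedged; reuses g from
the previous snippet and traces every output tensor):

graph_order = tensor_tracer_report.sort_tensors_and_ops(g)
print(graph_order.op_to_idx)           # {'a': 0, 'b': 1}
print(graph_order.tensor_to_idx)       # {'a:0': 0, 'b:0': 1}

order = tensor_tracer_report.TensorTraceOrder(graph_order,
                                              graph_order.tensors)
print(order.tensorname_to_cache_idx)   # {'a:0': 0, 'b:0': 1}
print(order.cache_idx_to_tensor_idx)   # [0, 1]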


class OpenReportFile(object):
  """Context manager for writing report file."""

  def __init__(self, tt_parameters):
    if not tt_parameters.report_file_path:
      self._report_file = None
      return
    try:
      self._report_file = gfile.Open(tt_parameters.report_file_path, 'w')
    except IOError as e:
      raise e

  def __enter__(self):
    return self._report_file

  def __exit__(self, unused_type, unused_value, unused_traceback):
    if self._report_file:
      self._report_file.close()
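A small usage sketch of the context manager (the _Params class and the
path are hypothetical stand-ins for a real TTParameters instance; only
report_file_path is consulted):

class _Params(object):
  report_file_path = '/tmp/tt_report.txt'

with OpenReportFile(_Params()) as report_file:
  if report_file:                 # None when report_file_path is unset
    report_file.write('section ...\n')
# __exit__ closes the file; when the body sees None, callers fall back to
# logging.info (see TTReportHandle._write_report below).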
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TTReportHandle(object):
  """Utility class responsible for creating a tensor tracer report."""

  def __init__(self):
    self.instrument_records = {}
    self._report_file = None

  def instrument(self, name, explanation):
    self.instrument_records[name] = explanation

  def instrument_op(self, op, explanation):
    self.instrument(op.name, explanation)

  def instrument_tensor(self, tensor, explanation):
    self.instrument(tensor.name, explanation)

  def create_report(self, tt_config, tt_parameters,
                    tensor_trace_order, tensor_trace_points):
    """Creates a report file and writes the trace information."""
    with OpenReportFile(tt_parameters) as self._report_file:
      self._write_config_section(tt_config, tt_parameters)
      self._write_op_list_section(tensor_trace_order.graph_order)
      self._write_tensor_list_section(tensor_trace_order.graph_order)
      self._write_trace_points(tensor_trace_points)
      self._write_cache_index_map_section(tensor_trace_order)
      self._write_reason_section()
      self._write_graph_section(tensor_trace_order.graph_order)

  def _write_trace_points(self, tensor_trace_points):
    """Writes the list of checkpoints."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))
    for (tensor, checkpoint_name) in tensor_trace_points:
      self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))

  def _write_report(self, content):
    """Writes the given content to the report."""

    line = '%s %s'%(_TRACER_LOG_PREFIX, content)
    if self._report_file:
      self._report_file.write(line)
    else:
      logging.info(line)

  def _write_config_section(self, tt_config, tt_parameters):
    """Writes the config section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
    self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, tt_config.version))
    self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, tt_config.device_type))
    self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
                                  tt_parameters.trace_mode))
    self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
                                  tt_parameters.submode))
    if tt_parameters.included_cores:
      self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
                                    len(tt_parameters.included_cores)))
    else:
      self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
                                    tt_config.num_replicas))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
                                  tt_config.num_replicas_per_host))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, tt_config.num_hosts))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))

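For illustration, a written config section is a begin marker, one field line per configuration value, and an end marker, each line carrying the tracer log prefix. Roughly (the field values below are made up):

 [>>>TT>>>] !!!!!!! section-begin: configuration
 [>>>TT>>>] version: 1
 [>>>TT>>>] device: tpu
 [>>>TT>>>] trace-mode: norm
 [>>>TT>>>] submode: detailed
 [>>>TT>>>] num-replicas: 8
 [>>>TT>>>] num-replicas-per-host: 8
 [>>>TT>>>] num-hosts: 1
 [>>>TT>>>] !!!!!!! section-end: configuration
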
  def _write_reason_section(self):
    """Writes the reason section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
    for key in sorted(self.instrument_records):
      self._write_report('"%s" %s\n'%(key, self.instrument_records[key]))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))

  def _write_op_list_section(self, graph_order):
    """Writes the Op-list section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS,
                                  len(graph_order.operations)))
    for i, op in enumerate(graph_order.operations):
      line = '%d "%s" %s'%(i, op.name, op.type)
      for out_tensor in op.outputs:
        if out_tensor.name not in graph_order.tensor_to_idx:
          raise ValueError(
              'out_tensor %s is not in tensor_to_idx'%out_tensor.name)
        line += ' %d'%graph_order.tensor_to_idx[out_tensor.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))

  def _write_tensor_list_section(self, graph_order):
    """Writes the tensor-list section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS,
                                  len(graph_order.tensors)))
    for i, tensor in enumerate(graph_order.tensors):
      line = '%d "%s"'%(i, tensor.name)
      consumers = tensor.consumers()
      consumers.sort(key=lambda op: op.name)
      for consumer_op in consumers:
        if consumer_op.name not in graph_order.op_to_idx:
          raise ValueError(
              'consumer_op %s is not in op_to_idx'%consumer_op.name)
        line += ' %d'%graph_order.op_to_idx[consumer_op.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_LIST))

  def _write_cache_index_map_section(self, tensor_trace_order):
    """Writes the mapping from cache index to tensor index to the report."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_CACHE_INDEX_MAP))
    self._write_report('%s %d\n'%(
        _FIELD_NAME_NUM_CACHE_INDICES,
        len(tensor_trace_order.cache_idx_to_tensor_idx)))
    for cache_idx, tensor_idx in enumerate(
        tensor_trace_order.cache_idx_to_tensor_idx):
      self._write_report('%d %d\n'%(cache_idx, tensor_idx))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_CACHE_INDEX_MAP))

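Reading a section back is a matter of scanning between its begin and end markers. A hypothetical reader for the cache-index-map section (error handling omitted; assumes the entry format written above, with the count line skipped):

# Hypothetical reader for the cache-index-map section; assumes entry lines
# of the form '<log prefix> <cache_idx> <tensor_idx>' between the markers.
def read_cache_index_map(report_lines):
  cache_idx_to_tensor_idx = {}
  in_section = False
  for line in report_lines:
    if 'section-begin: cache-index-map' in line:
      in_section = True
    elif 'section-end: cache-index-map' in line:
      break
    elif in_section:
      fields = line.split()
      # Skip the 'number-of-indices:' header; entry lines end in two ints.
      if len(fields) >= 2 and fields[-2].isdigit() and fields[-1].isdigit():
        cache_idx_to_tensor_idx[int(fields[-2])] = int(fields[-1])
  return cache_idx_to_tensor_idx
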
  def _write_graph_section(self, graph_order):
    """Writes the graph section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
    self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
                                  not graph_order.contains_cycle))
    topological_order = list(graph_order.topological_order_or_cycle)
    for i, op in enumerate(topological_order):
      self._write_report('%d "%s"\n'%(i, op.name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
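
An end-to-end sketch of the report handle; every config, parameter, and order value below is a hand-built stand-in for the objects TensorTracer normally supplies, and with an empty report_file_path the output goes to logging.info rather than a file:

# End-to-end sketch; _FakeConfig/_FakeParams/_FakeOrder are hand-built
# stand-ins for TensorTracer's real config, parameter, and order objects.
import collections
import tensorflow.compat.v1 as tf

_FakeConfig = collections.namedtuple(
    '_FakeConfig', ['version', 'device_type', 'num_replicas',
                    'num_replicas_per_host', 'num_hosts'])
_FakeParams = collections.namedtuple(
    '_FakeParams', ['report_file_path', 'trace_mode', 'submode',
                    'included_cores'])
_FakeOrder = collections.namedtuple(
    '_FakeOrder', ['graph_order', 'tensorname_to_cache_idx',
                   'cache_idx_to_tensor_idx'])

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name='a')
  b = tf.identity(a, name='b')

graph_order = sort_tensors_and_ops(g)
order = _FakeOrder(
    graph_order=graph_order,
    tensorname_to_cache_idx={'b:0': 0},
    cache_idx_to_tensor_idx=[graph_order.tensor_to_idx['b:0']])

handle = TTReportHandle()
handle.instrument_op(g.get_operation_by_name('a'), _REASON_USER_EXCLUDED)
handle.create_report(
    _FakeConfig(version=1, device_type='tpu', num_replicas=8,
                num_replicas_per_host=8, num_hosts=1),
    _FakeParams(report_file_path='', trace_mode='norm',  # made-up modes
                submode='detailed', included_cores=None),
    order, tensor_trace_points=[])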