Enhance the Tensor-Tracer in the following ways:
(1) Able to trace tensors when the model is executed on the CPU. (previously, it could only trace when the model is executed on TPU) (2) Allow the user to specify the op-names and op-types to be excluded or included for tracing via regular expressions. (3) Two new trace modes: (1) tracing the vector norm of the tensor and (2) tracing the maximum of the absolute values of all elements in the tensor. (4) Attach the replica-ID to a traced tensor value so that the post-processing tool (Tensor-Inspector) can reconstruct the whole tensor from all replicas. (5) An API to trace tensors programmatically. (6) Allow writing the trace to stdout (previously, it must be written to a file). PiperOrigin-RevId: 225112219
This commit is contained in:
parent
4b974cf1c1
commit
8ac99aa0ec
@ -21,44 +21,56 @@ from __future__ import print_function
|
||||
import os
|
||||
import os.path
|
||||
import re
|
||||
import sys
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
|
||||
_DEVICE_TYPE_TPU = 'tpu'
|
||||
_DEVICE_TYPE_CPU = 'cpu'
|
||||
_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP'
|
||||
_TRACE_MODE_NAN_INF = 'nan-inf'
|
||||
_TRACE_MODE_PART_TENSOR = 'part-tensor'
|
||||
_TRACE_MODE_PART_TENSOR_SIZE = 3
|
||||
_TRACE_MODE_FULL_TENSOR = 'full-tensor'
|
||||
_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
|
||||
_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace'
|
||||
_RECORD_FILTERED_OUT = 'not-traced-filtered-out'
|
||||
_RECORD_SCALAR = 'not-traced-scalar'
|
||||
_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
|
||||
_RECORD_GET_TRACED = 'get-traced'
|
||||
_TRACE_MODE_NORM = 'norm'
|
||||
_TRACE_MODE_MAX_ABS = 'max-abs'
|
||||
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
|
||||
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
|
||||
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
|
||||
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
|
||||
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
|
||||
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
|
||||
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
|
||||
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
|
||||
_REASON_USER_INCLUDED = 'traced-user-included'
|
||||
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
|
||||
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
|
||||
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
|
||||
_MARKER_SECTION_END = '!!!!!!! section-end:'
|
||||
_SECTION_NAME_CONFIG = 'configuration'
|
||||
_SECTION_NAME_REASON = 'reason'
|
||||
_SECTION_NAME_OP_LIST = 'op-list'
|
||||
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
|
||||
_SECTION_NAME_GRAPH = 'graph'
|
||||
_FIELD_NAME_VERSION = 'version:'
|
||||
_FIELD_NAME_DEVICE = 'device:'
|
||||
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
|
||||
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
|
||||
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
|
||||
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
|
||||
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
|
||||
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
|
||||
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
|
||||
@ -66,13 +78,72 @@ _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
|
||||
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
|
||||
_FLAG_NAME_ENABLE = 'enable'
|
||||
_FLAG_NAME_TRACE_MODE = 'trace_mode'
|
||||
_FLAG_NAME_INTERESTING_OPS = 'interesting_ops'
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
|
||||
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
|
||||
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
|
||||
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
|
||||
_FLAG_NAME_TRACE_FILE = 'trace_file_path'
|
||||
_FLAG_NAME_REPORT_FILE = 'report_file_path'
|
||||
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
|
||||
_FLAG_NAME_OP_RANGE = 'op_range'
|
||||
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
|
||||
_OUTPUT_STREAM_ESCAPE = 'file://'
|
||||
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
|
||||
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
|
||||
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
|
||||
|
||||
|
||||
def tensor_checkpoint(tensor, checkpoint_name):
|
||||
"""Adds a checkpoint with the given checkpoint name for the given tensor.
|
||||
|
||||
The tensor will be added to the list of tensors that will be traced by the
|
||||
tensor tracer.
|
||||
|
||||
Args:
|
||||
tensor: the tensor object for which the tracing is requested.
|
||||
checkpoint_name: a string name for the checkpoint. This name has to be a
|
||||
unique name if used within model comparison. The tensors that have the same
|
||||
checkpoint identifier is compared in model comparison.
|
||||
Returns:
|
||||
The provided tensor.
|
||||
"""
|
||||
|
||||
tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION)
|
||||
tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
|
||||
(tensor, checkpoint_name))
|
||||
return tensor
|
||||
|
||||
|
||||
def keras_layer_checkpoint(layer, checkpoint_name):
|
||||
"""An interface for adding the tensor outputs of a keras layer.
|
||||
|
||||
Encapsulates tensor_checkpoint.
|
||||
|
||||
Args:
|
||||
layer: A keras layer.
|
||||
checkpoint_name: a string name for the checkpoint. This name has to be a
|
||||
unique name if used within model comparison. The tensors that have the same
|
||||
checkpoint identifier is compared in model comparison.
|
||||
|
||||
Returns:
|
||||
The provided layer.
|
||||
"""
|
||||
try:
|
||||
outputs = layer.output
|
||||
if tensor_util.is_tensor(outputs):
|
||||
tensor_checkpoint(outputs, '%s' % (checkpoint_name))
|
||||
else:
|
||||
idx = 0
|
||||
for output_tensor in outputs:
|
||||
if tensor_util.is_tensor(outputs):
|
||||
tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx))
|
||||
idx += 1
|
||||
except AttributeError:
|
||||
pass
|
||||
except RuntimeError:
|
||||
pass
|
||||
return layer
|
||||
|
||||
|
||||
class TensorTracer(object):
|
||||
@ -105,6 +176,34 @@ class TensorTracer(object):
|
||||
match = _FLAG_NO_QUOTE_PAT.match(flags, pos)
|
||||
return match
|
||||
|
||||
@staticmethod
|
||||
def validate_flag_names():
|
||||
"""Validates if the TensorTrace flags passed are valid."""
|
||||
valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE,
|
||||
_FLAG_NAME_EXCLUDED_OPNAMES,
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES,
|
||||
_FLAG_NAME_INCLUDED_OPNAMES,
|
||||
_FLAG_NAME_INCLUDED_OPTYPES,
|
||||
_FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE,
|
||||
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS,
|
||||
_FLAG_NAME_OP_RANGE]
|
||||
tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR)
|
||||
if not tensor_tracer_flags:
|
||||
return
|
||||
pos = 0
|
||||
while True:
|
||||
match = TensorTracer._match_next_flag(tensor_tracer_flags, pos)
|
||||
if not match:
|
||||
break
|
||||
flag_name = match.group(1)
|
||||
if flag_name not in valid_flag_names:
|
||||
raise ValueError(
|
||||
'The flag name "%s" passed via the environment variable "%s" '
|
||||
'is invalid. Valid flag names are:'
|
||||
'\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names))
|
||||
pos = match.end()
|
||||
|
||||
@staticmethod
|
||||
def print_flag_values():
|
||||
"""Prints all TensorTracer flags passed via environment variables."""
|
||||
@ -146,6 +245,20 @@ class TensorTracer(object):
|
||||
pos = match.end()
|
||||
return ''
|
||||
|
||||
@staticmethod
|
||||
def flag_value_to_re_list(flag_name):
|
||||
"""Converts list of strings to compiled RE."""
|
||||
|
||||
re_list = []
|
||||
flag_value = TensorTracer.get_flag_value(flag_name)
|
||||
if not flag_value:
|
||||
return re_list
|
||||
list_of_values = flag_value.split()
|
||||
for v in list_of_values:
|
||||
r = re.compile(v)
|
||||
re_list.append(r)
|
||||
return re_list
|
||||
|
||||
@staticmethod
|
||||
def is_enabled():
|
||||
"""Returns True if TensorTracer is enabled."""
|
||||
@ -186,29 +299,67 @@ class TensorTracer(object):
|
||||
"""Checks if the given trace mode is valid."""
|
||||
|
||||
valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR,
|
||||
_TRACE_MODE_FULL_TENSOR]
|
||||
_TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM,
|
||||
_TRACE_MODE_MAX_ABS]
|
||||
if trace_mode not in valid_trace_modes:
|
||||
raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.'
|
||||
'Valid trace modes are: %s'%(trace_mode,
|
||||
valid_trace_modes))
|
||||
|
||||
@staticmethod
|
||||
def should_trace(device_type, op):
|
||||
"""Returns True if the given Op should be traced."""
|
||||
def unsafe_op(op):
|
||||
"""Returns True if this op is not safe to be traced."""
|
||||
|
||||
if device_type != _DEVICE_TYPE_TPU:
|
||||
raise ValueError('Non TPU device type is not supported')
|
||||
if control_flow_util.IsInCond(op):
|
||||
return True
|
||||
# Reasons for not including following op types:
|
||||
# Assign: cause incorrect result with CPU tracing.
|
||||
# others: compilation problems.
|
||||
if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def device_mismatch(device_type, op):
|
||||
if device_type == _DEVICE_TYPE_TPU:
|
||||
# pylint: disable=protected-access
|
||||
return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr
|
||||
# pylint: enable=protected-access
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def unsafe_scalar_trace(op):
|
||||
"""Return true if scalar output tensor from Op is not safe to be traced."""
|
||||
|
||||
# Tracing the following causes cycle in the graph on TPU.
|
||||
if op.type in ['LoopCond', 'Enter', 'Merge', 'Const',
|
||||
'Switch', 'Less', 'ReadVariableOp']:
|
||||
return True
|
||||
# Tracing the following will cause casting-issue
|
||||
# with the norm tracing mode or other compilation issues on CPU.
|
||||
if op.type in ['VarHandleOp', 'IteratorToStringHandle',
|
||||
'IteratorGetNext', 'OneShotIterator',
|
||||
'IteratorV2', 'MakeIterator',
|
||||
'BatchDatasetV2', 'MapDataset',
|
||||
'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
|
||||
'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def less_interesting_op(op):
|
||||
"""Returns True if the given Op is not an interesting one to be traced."""
|
||||
|
||||
include_less_interesting = TensorTracer.get_flag_value(
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)
|
||||
if include_less_interesting:
|
||||
return False
|
||||
if op.type in ['Reshape', 'ArgMin', 'ArgMax']:
|
||||
return False
|
||||
# pylint: disable=protected-access
|
||||
return tpu._TPU_REPLICATE_ATTR in op.node_def.attr
|
||||
# pylint: enable=protected-access
|
||||
return op.type in ['Const', 'Identity', 'Cast', 'Shape']
|
||||
|
||||
@staticmethod
|
||||
def reason(op_idx, details):
|
||||
"""Returns why the Op at op_idx is traced or not."""
|
||||
"""Returns reason why the Op at op_idx is traced or not."""
|
||||
|
||||
return '%d %s'%(op_idx, details)
|
||||
|
||||
@staticmethod
|
||||
@ -274,6 +425,33 @@ class TensorTracer(object):
|
||||
assert len(unsorted_ops) == len(sorted_ops)
|
||||
return (True, sorted_ops)
|
||||
|
||||
@staticmethod
|
||||
def _make_op_and_tensor_maps(op_list):
|
||||
"""Creates various maps and lists from op_list.
|
||||
|
||||
Args:
|
||||
op_list: a list of Ops
|
||||
|
||||
Returns:
|
||||
opname_idx_map: a map from Op's name to its index in op_list.
|
||||
tensor_list: a list of output tensors of the Ops in op_list.
|
||||
tensorname_idx_map: a map from output tensor name to its index
|
||||
in tensor_list.
|
||||
"""
|
||||
|
||||
opname_idx_map = {}
|
||||
tensor_list = []
|
||||
tensorname_idx_map = {}
|
||||
for op_id, op in enumerate(op_list):
|
||||
if op.name in opname_idx_map:
|
||||
raise ValueError('Duplicated Op name: %s'%op.name)
|
||||
opname_idx_map[op.name] = op_id
|
||||
for output_tensor in op.outputs:
|
||||
if output_tensor.name not in tensorname_idx_map:
|
||||
tensor_list.append(output_tensor)
|
||||
tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
|
||||
return (opname_idx_map, tensor_list, tensorname_idx_map)
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a TensorTracer.
|
||||
|
||||
@ -281,16 +459,20 @@ class TensorTracer(object):
|
||||
"""
|
||||
self._version = 'use-outside-compilation'
|
||||
self._device_type = None
|
||||
TensorTracer.validate_flag_names()
|
||||
self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE)
|
||||
if not self._trace_mode:
|
||||
self._trace_mode = _TRACE_MODE_NAN_INF
|
||||
TensorTracer.check_trace_mode(self._trace_mode)
|
||||
self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
|
||||
self._instrument_records = {}
|
||||
interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS)
|
||||
self._selected_ops = interesting_ops.split()
|
||||
self._set_trace_file_path()
|
||||
self._set_report_file()
|
||||
self._set_op_range()
|
||||
self._set_excluded_opnames()
|
||||
self._set_excluded_optypes()
|
||||
self._set_included_opnames()
|
||||
self._set_included_optypes()
|
||||
self._num_replicas = None
|
||||
self._replica_id = None
|
||||
|
||||
@ -318,10 +500,7 @@ class TensorTracer(object):
|
||||
"""Sets the path of the output trace file."""
|
||||
|
||||
self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE)
|
||||
if not self._trace_file_path:
|
||||
raise ValueError('--%s is not set in the environment variable %s'
|
||||
%(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR))
|
||||
elif TensorTracer.use_test_undeclared_outputs_dir():
|
||||
if self._trace_file_path and TensorTracer.use_test_undeclared_outputs_dir():
|
||||
if os.path.isabs(self._trace_file_path):
|
||||
raise ValueError('If use_test_undeclared_outputs_dir is set,'
|
||||
'trace_file_path cannot be an absolute path (%s)'
|
||||
@ -330,6 +509,22 @@ class TensorTracer(object):
|
||||
self._trace_file_path = os.path.join(outputs_dir,
|
||||
self._trace_file_path)
|
||||
|
||||
def _set_report_file(self):
|
||||
"""Sets the path of the output report file."""
|
||||
|
||||
self._report_file_path = TensorTracer.get_flag_value(_FLAG_NAME_REPORT_FILE)
|
||||
if not self._report_file_path:
|
||||
self._report_file = None
|
||||
return
|
||||
try:
|
||||
self._report_file = gfile.Open(self._report_file_path, 'w')
|
||||
except IOError as e:
|
||||
raise e
|
||||
|
||||
def _close_report_file(self):
|
||||
if self._report_file:
|
||||
self._report_file.close()
|
||||
|
||||
def _set_op_range(self):
|
||||
"""Sets the index range of the Ops that we will consider tracing."""
|
||||
|
||||
@ -350,19 +545,48 @@ class TensorTracer(object):
|
||||
return False
|
||||
return self._op_range[1] < 0 or idx <= self._op_range[1]
|
||||
|
||||
def _set_excluded_opnames(self):
|
||||
self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_EXCLUDED_OPNAMES)
|
||||
|
||||
def _set_excluded_optypes(self):
|
||||
self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES)
|
||||
|
||||
def _set_included_opnames(self):
|
||||
self._included_opname_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_INCLUDED_OPNAMES)
|
||||
|
||||
def _set_included_optypes(self):
|
||||
self._included_optype_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_INCLUDED_OPTYPES)
|
||||
|
||||
def _is_user_included_op(self, op):
|
||||
for opname_re in self._included_opname_re_list:
|
||||
if opname_re.match(op.name):
|
||||
return True
|
||||
for optype_re in self._included_optype_re_list:
|
||||
if optype_re.match(op.type):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_user_excluded_op(self, op):
|
||||
for opname_re in self._excluded_opname_re_list:
|
||||
if opname_re.match(op.name):
|
||||
return True
|
||||
for optype_re in self._excluded_optype_re_list:
|
||||
if optype_re.match(op.type):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _write_report(self, content):
|
||||
"""Writes the given content to the report."""
|
||||
|
||||
logging.info('%s %s'%(_TRACER_LOG_PREFIX, content))
|
||||
|
||||
def _is_selected_op(self, op_name):
|
||||
"""Returns True if the Op with op_name is selected to be traced."""
|
||||
|
||||
if not self._selected_ops:
|
||||
return True
|
||||
if op_name in self._selected_ops:
|
||||
return True
|
||||
return False
|
||||
line = '%s %s'%(_TRACER_LOG_PREFIX, content)
|
||||
if self._report_file:
|
||||
self._report_file.write(line)
|
||||
else:
|
||||
logging.info(line)
|
||||
|
||||
def _write_config_section(self):
|
||||
"""Writes the config section of the report."""
|
||||
@ -382,15 +606,42 @@ class TensorTracer(object):
|
||||
self._write_report('"%s" %s\n'%(key, self._instrument_records[key]))
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))
|
||||
|
||||
def _write_op_list_section(self, op_list):
|
||||
def _write_op_list_section(self, op_list, tensorname_idx_map):
|
||||
"""Writes the Op-list section of the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
|
||||
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list)))
|
||||
for i in range(0, len(op_list)):
|
||||
self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type))
|
||||
op = op_list[i]
|
||||
line = '%d "%s" %s'%(i, op.name, op.type)
|
||||
for out_tensor in op.outputs:
|
||||
if out_tensor.name not in tensorname_idx_map:
|
||||
raise ValueError(
|
||||
'out_tensor %s is not in tensorname_idx_map'%out_tensor.name)
|
||||
line += ' %d'%tensorname_idx_map[out_tensor.name]
|
||||
line += '\n'
|
||||
self._write_report(line)
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))
|
||||
|
||||
def _write_tensor_list_section(self, tensor_list, opname_idx_map):
|
||||
"""Writes the tensor-list section of the report."""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
|
||||
_SECTION_NAME_TENSOR_LIST))
|
||||
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list)))
|
||||
for i in range(0, len(tensor_list)):
|
||||
tensor = tensor_list[i]
|
||||
line = '%d "%s"'%(i, tensor.name)
|
||||
for consumer_op in tensor.consumers():
|
||||
if consumer_op.name not in opname_idx_map:
|
||||
raise ValueError(
|
||||
'consumer_op %s is not in opname_idx_map'%consumer_op.name)
|
||||
line += ' %d'%opname_idx_map[consumer_op.name]
|
||||
line += '\n'
|
||||
self._write_report(line)
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
|
||||
_SECTION_NAME_TENSOR_LIST))
|
||||
|
||||
def _write_graph_section(self, succeed, sorted_or_cycle):
|
||||
"""Writes the graph section of the report."""
|
||||
|
||||
@ -422,7 +673,7 @@ class TensorTracer(object):
|
||||
Args:
|
||||
op_name: the name of the Op that outputs the tensor to be printed.
|
||||
output_idx: which output of the Op it is (0 means the first output).
|
||||
num_elements: number of elements to print.
|
||||
num_elements: number of elements to print (-1 means print all).
|
||||
tensor: the tensor needs to be returned.
|
||||
output_tensor: the tensor needs to be printed.
|
||||
|
||||
@ -430,10 +681,13 @@ class TensorTracer(object):
|
||||
The same tensor passed via the "tensor" argument.
|
||||
"""
|
||||
msg = '"%s:%d" '%(op_name, output_idx)
|
||||
output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path
|
||||
if self._trace_file_path:
|
||||
output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path
|
||||
else:
|
||||
output_stream = sys.stderr
|
||||
print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor),
|
||||
' @', self._replica_id,
|
||||
'\n', output_tensor,
|
||||
'\n', output_tensor, '\n',
|
||||
summarize=num_elements,
|
||||
output_stream=output_stream)
|
||||
with ops.control_dependencies([print_op]):
|
||||
@ -442,7 +696,8 @@ class TensorTracer(object):
|
||||
def _detect_nan_inf(tensor):
|
||||
"""Trace function for detecting any NaN/Inf in the tensor."""
|
||||
|
||||
if tensor.dtype.is_floating:
|
||||
if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
|
||||
dtypes.float16):
|
||||
# Since host can't handle bf16, always convert tensor to f32.
|
||||
tensor = math_ops.cast(tensor, dtypes.float32)
|
||||
output_tensor = math_ops.reduce_any(
|
||||
@ -450,12 +705,19 @@ class TensorTracer(object):
|
||||
gen_math_ops.is_inf(tensor)))
|
||||
else:
|
||||
output_tensor = constant_op.constant(0)
|
||||
return _print_tensor(op_name, output_idx, 1, tensor, output_tensor)
|
||||
return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)
|
||||
|
||||
def _show_global_step(tensor):
|
||||
"""Trace function for printing the global step count."""
|
||||
def _show_norm(tensor):
|
||||
tensor = math_ops.cast(tensor, dtypes.float64)
|
||||
output_tensor = linalg_ops.norm(tensor)
|
||||
return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)
|
||||
|
||||
return _print_tensor(op_name, output_idx, 1, tensor, tensor)
|
||||
def _show_max_abs(tensor):
|
||||
output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)),
|
||||
dtypes.float64)
|
||||
zero = constant_op.constant(0, dtypes.float64)
|
||||
output_tensor = gen_math_ops.maximum(zero, output_tensor)
|
||||
return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)
|
||||
|
||||
def _show_part_tensor(tensor):
|
||||
"""Trace function for printing part of the tensor."""
|
||||
@ -468,23 +730,139 @@ class TensorTracer(object):
|
||||
|
||||
return _print_tensor(op_name, output_idx, -1, tensor, tensor)
|
||||
|
||||
if op_name == _GLOBAL_STEP_OP_NAME:
|
||||
return _show_global_step
|
||||
if self._trace_mode == _TRACE_MODE_NAN_INF:
|
||||
return _detect_nan_inf
|
||||
if self._trace_mode == _TRACE_MODE_PART_TENSOR:
|
||||
return _show_part_tensor
|
||||
if self._trace_mode == _TRACE_MODE_FULL_TENSOR:
|
||||
return _show_full_tensor
|
||||
if self._trace_mode == _TRACE_MODE_NORM:
|
||||
return _show_norm
|
||||
if self._trace_mode == _TRACE_MODE_MAX_ABS:
|
||||
return _show_max_abs
|
||||
|
||||
raise RuntimeError('Tensor trace fun for %s is not yet implemented'
|
||||
%self._trace_mode)
|
||||
|
||||
def _skip_op(self, op_id, op, user_included, user_excluded):
|
||||
"""Returns True if we should not trace Op."""
|
||||
|
||||
if user_included:
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_INCLUDED)
|
||||
return False
|
||||
if user_excluded:
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_EXCLUDED)
|
||||
return True
|
||||
if not self._inside_op_range(op_id):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_OUTSIDE_OP_RANGE)
|
||||
return True
|
||||
if TensorTracer.unsafe_op(op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_UNSAFE_OP)
|
||||
return True
|
||||
if TensorTracer.device_mismatch(self._device_type, op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_DEVICE_MISMATCH)
|
||||
return True
|
||||
if TensorTracer.less_interesting_op(op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_LESS_INTERESTING_OP)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _skip_tensor(self, op_id, out_tensor, user_included,
|
||||
user_excluded):
|
||||
"""Returns True if we should not trace out_tensor."""
|
||||
|
||||
# Skips a tensor if the tensor has a non-numeric type.
|
||||
# Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
|
||||
# because it also excludes tensors with dtypes, bool, and
|
||||
# float32_ref, which we actually want to trace.
|
||||
non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
|
||||
dtypes.string])
|
||||
if out_tensor.dtype in non_numeric_tensor_types:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_NON_NUMERIC_TENSOR)
|
||||
return True
|
||||
|
||||
if user_included:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_INCLUDED)
|
||||
return False
|
||||
if user_excluded:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_USER_EXCLUDED)
|
||||
return True
|
||||
if not out_tensor.get_shape().is_fully_defined():
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_DYNAMIC_SHAPE)
|
||||
return True
|
||||
rank = len(out_tensor.shape)
|
||||
if rank < 1:
|
||||
# scalar
|
||||
if TensorTracer.unsafe_scalar_trace(out_tensor.op):
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_UNSAFE_SCALAR)
|
||||
return True
|
||||
else:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_SCALAR_GET_TRACED)
|
||||
return False
|
||||
else:
|
||||
# tensor
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_TENSOR_GET_TRACED)
|
||||
return False
|
||||
|
||||
def _pre_tracing(self, graph):
|
||||
"""Work needs to be done prior to TPU or CPU tracing."""
|
||||
|
||||
operations = graph.get_operations()
|
||||
(opname_idx_map, tensor_list, tensorname_idx_map) = (
|
||||
TensorTracer._make_op_and_tensor_maps(operations))
|
||||
self._write_config_section()
|
||||
self._write_op_list_section(operations, tensorname_idx_map)
|
||||
self._write_tensor_list_section(tensor_list, opname_idx_map)
|
||||
# Does the topological sort before adding any nodes to the graph.
|
||||
(succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
|
||||
return (operations, succeed, sorted_or_cycle)
|
||||
|
||||
def _post_tracing(self, succeed, sorted_or_cycle):
|
||||
"""Work needs to be done after TPU or CPU tracing."""
|
||||
|
||||
self._write_reason_section()
|
||||
self._write_graph_section(succeed, sorted_or_cycle)
|
||||
self._close_report_file()
|
||||
|
||||
def _get_checkpoints(self, graph):
|
||||
"""Returns the list of Ops that produce the tensors traced with API.
|
||||
|
||||
Args:
|
||||
graph: the graph of Ops.
|
||||
|
||||
Returns:
|
||||
A set of operation names which should be traced.
|
||||
"""
|
||||
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
|
||||
_TENSOR_TRACER_CHECKPOINT))
|
||||
checkpoint_operations = set()
|
||||
tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION)
|
||||
for (tensor, checkpoint_name) in tensor_tracer_variables:
|
||||
self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
|
||||
checkpoint_operations.add(tensor.op.name)
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
|
||||
_TENSOR_TRACER_CHECKPOINT))
|
||||
return checkpoint_operations
|
||||
|
||||
def trace_tpu(self, graph, result_tensor, num_replicas=None):
|
||||
"""Traces the tensors generated by TPU Ops in a TF graph.
|
||||
|
||||
Args:
|
||||
graph: the graph of Ops.
|
||||
graph: the graph of Ops executed on the TPU.
|
||||
result_tensor: a result tensor of evaluating the graph.
|
||||
num_replicas: number of replicas used on the TPU.
|
||||
|
||||
@ -502,38 +880,22 @@ class TensorTracer(object):
|
||||
TensorTracer.check_device_type(self._device_type)
|
||||
result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
|
||||
result_tensor)
|
||||
self._write_config_section()
|
||||
(operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
|
||||
tracing_ops = []
|
||||
operations = graph.get_operations()
|
||||
self._write_op_list_section(operations)
|
||||
# Does the topological sort before adding any nodes to the graph.
|
||||
(succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
|
||||
checkpoint_operations = self._get_checkpoints(graph)
|
||||
|
||||
for op_id, op in enumerate(operations):
|
||||
if not self._inside_op_range(op_id):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _RECORD_OUTSIDE_OP_RANGE)
|
||||
if checkpoint_operations and op.name not in checkpoint_operations:
|
||||
continue
|
||||
if not TensorTracer.should_trace(self._device_type, op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _RECORD_SHOULD_NOT_TRACE)
|
||||
continue
|
||||
if not self._is_selected_op(op.name):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _RECORD_FILTERED_OUT)
|
||||
user_included = self._is_user_included_op(op)
|
||||
user_excluded = self._is_user_excluded_op(op)
|
||||
if self._skip_op(op_id, op, user_included, user_excluded):
|
||||
continue
|
||||
for i in range(len(op.outputs)):
|
||||
out_tensor = op.outputs[i]
|
||||
if not out_tensor.get_shape().is_fully_defined():
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _RECORD_DYNAMIC_SHAPE)
|
||||
continue # cannot trace tensors with dynamic shape.
|
||||
rank = len(out_tensor.shape)
|
||||
if rank < 1:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _RECORD_SCALAR)
|
||||
continue # cannot trace scalar.
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _RECORD_GET_TRACED)
|
||||
if self._skip_tensor(op_id, out_tensor, user_included,
|
||||
user_excluded):
|
||||
continue
|
||||
consumers = out_tensor.consumers()
|
||||
trace_op = tpu.outside_compilation(
|
||||
self._make_tensor_trace_fun(op.name, i), out_tensor)
|
||||
@ -546,8 +908,45 @@ class TensorTracer(object):
|
||||
# if there is no consumer, we will add the control dependence later
|
||||
# when we add the control dependency to the output operations.
|
||||
tracing_ops.append(trace_op)
|
||||
|
||||
self._write_reason_section()
|
||||
self._write_graph_section(succeed, sorted_or_cycle)
|
||||
|
||||
self._post_tracing(succeed, sorted_or_cycle)
|
||||
return (result_tensor_copy, tracing_ops)
|
||||
|
||||
def trace_cpu(self, graph):
|
||||
"""Traces the tensors generated by CPU Ops in a TF graph.
|
||||
|
||||
Args:
|
||||
graph: the graph of Ops executed on the CPU.
|
||||
|
||||
Returns:
|
||||
tracing_calls: a map from keys to trace calls.
|
||||
A key is constructed from an Op's name.
|
||||
A trace call consists of a function and a tensor (
|
||||
the function will be invoked with the tensor).
|
||||
"""
|
||||
|
||||
self._device_type = _DEVICE_TYPE_CPU
|
||||
TensorTracer.check_device_type(self._device_type)
|
||||
self._num_replicas = 1
|
||||
self._replica_id = 0
|
||||
(operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
|
||||
tracing_calls = {}
|
||||
checkpoint_operations = self._get_checkpoints(graph)
|
||||
|
||||
for op_id, op in enumerate(operations):
|
||||
if checkpoint_operations and op.name not in checkpoint_operations:
|
||||
continue
|
||||
user_included = self._is_user_included_op(op)
|
||||
user_excluded = self._is_user_excluded_op(op)
|
||||
if self._skip_op(op_id, op, user_included, user_excluded):
|
||||
continue
|
||||
for i in range(len(op.outputs)):
|
||||
out_tensor = op.outputs[i]
|
||||
if self._skip_tensor(op_id, out_tensor, user_included,
|
||||
user_excluded):
|
||||
continue
|
||||
trace_fun = self._make_tensor_trace_fun(op.name, i)
|
||||
trace_call = (trace_fun, [out_tensor])
|
||||
trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i)
|
||||
tracing_calls[trace_call_key] = trace_call
|
||||
self._post_tracing(succeed, sorted_or_cycle)
|
||||
return tracing_calls
|
||||
|
@ -336,6 +336,16 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
|
||||
hooks = None
|
||||
if self.host_call is not None:
|
||||
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
|
||||
if tensor_tracer.TensorTracer.is_enabled():
|
||||
tt = tensor_tracer.TensorTracer()
|
||||
tracing_calls = tt.trace_cpu(ops.get_default_graph())
|
||||
tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls)
|
||||
tracing_functions = tracing_call_ret.values()
|
||||
if tracing_functions:
|
||||
if hooks:
|
||||
hooks.extend([_OutfeedHostCallHook(tracing_functions)])
|
||||
else:
|
||||
hooks = [_OutfeedHostCallHook(tracing_functions)]
|
||||
hooks = tuple(hooks or [])
|
||||
scaffold = self.scaffold_fn() if self.scaffold_fn else None
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
|
Loading…
x
Reference in New Issue
Block a user