Enhance the Tensor-Tracer in the following ways:

(1) Tensors can now be traced when the model is executed on the CPU
    (previously, tracing worked only when the model was executed on a TPU).
(2) Allow the user to specify, via regular expressions, the op-names and op-types to be included in or excluded from tracing.
(3) Two new trace modes: tracing the vector norm of a tensor, and tracing the maximum of the absolute values of all elements in a tensor.
(4) Attach the replica ID to each traced tensor value so that the post-processing tool (Tensor-Inspector) can reconstruct the whole tensor from all replicas.
(5) An API to trace tensors programmatically.
(6) Allow writing the trace to stdout (previously, it had to be written to a file).
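
For illustration, the filtering and trace-mode features are configured through
the TENSOR_TRACER_FLAGS environment variable. A minimal sketch using the flag
names defined in tensor_tracer.py below (the specific values are hypothetical):

    import os

    # Trace the vector norm of ops whose names start with 'dense',
    # never trace MatMul ops, and write the trace to a file.
    os.environ['TENSOR_TRACER_FLAGS'] = (
        '--enable=1 --trace_mode=norm '
        '--included_opnames="dense.*" --excluded_optypes="MatMul" '
        '--trace_file_path=/tmp/tensor_trace.log')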

PiperOrigin-RevId: 225112219
A. Unique TensorFlower 2018-12-11 19:09:14 -08:00 committed by TensorFlower Gardener
parent 4b974cf1c1
commit 8ac99aa0ec
2 changed files with 487 additions and 78 deletions


@@ -21,44 +21,56 @@ from __future__ import print_function
import os
import os.path
import re
import sys
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
_TRACE_MODE_NAN_INF = 'nan-inf'
_TRACE_MODE_PART_TENSOR = 'part-tensor'
_TRACE_MODE_PART_TENSOR_SIZE = 3
_TRACE_MODE_FULL_TENSOR = 'full-tensor'
_TRACE_MODE_NORM = 'norm'
_TRACE_MODE_MAX_ABS = 'max-abs'
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'
_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_GRAPH = 'graph'
_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
@@ -66,13 +78,72 @@ _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
_FLAG_NAME_ENABLE = 'enable'
_FLAG_NAME_TRACE_MODE = 'trace_mode'
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
_FLAG_NAME_TRACE_FILE = 'trace_file_path'
_FLAG_NAME_REPORT_FILE = 'report_file_path'
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
_FLAG_NAME_OP_RANGE = 'op_range'
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_OUTPUT_STREAM_ESCAPE = 'file://'
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
def tensor_checkpoint(tensor, checkpoint_name):
"""Adds a checkpoint with the given checkpoint name for the given tensor.
The tensor will be added to the list of tensors that will be traced by the
tensor tracer.
Args:
tensor: the tensor object for which the tracing is requested.
    checkpoint_name: a string name for the checkpoint. This name has to be
      unique if used within model comparison. Tensors that have the same
      checkpoint identifier are compared in model comparison.
Returns:
The provided tensor.
"""
tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION)
tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
(tensor, checkpoint_name))
return tensor
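A minimal usage sketch of this programmatic API (the tensors and the
checkpoint name are hypothetical):

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tensor_tracer

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    w = tf.constant([[0.5], [0.5]])
    # Registers the tensor in the tensor_tracer_variables collection
    # under the name 'logits' and returns it unchanged.
    logits = tensor_tracer.tensor_checkpoint(tf.matmul(x, w), 'logits')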
def keras_layer_checkpoint(layer, checkpoint_name):
"""An interface for adding the tensor outputs of a keras layer.
Encapsulates tensor_checkpoint.
Args:
layer: A keras layer.
    checkpoint_name: a string name for the checkpoint. This name has to be
      unique if used within model comparison. Tensors that have the same
      checkpoint identifier are compared in model comparison.
Returns:
The provided layer.
"""
try:
outputs = layer.output
if tensor_util.is_tensor(outputs):
tensor_checkpoint(outputs, '%s' % (checkpoint_name))
else:
idx = 0
for output_tensor in outputs:
      if tensor_util.is_tensor(output_tensor):
tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx))
idx += 1
except AttributeError:
pass
except RuntimeError:
pass
return layer
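Likewise, a layer's outputs can be checkpointed once the layer has been called
(so that layer.output exists); a sketch with a hypothetical layer:

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tensor_tracer

    dense = tf.keras.layers.Dense(8, name='dense_1')
    y = dense(tf.zeros([2, 4]))
    # Checkpoints every output tensor of the layer under 'dense_1_out'.
    dense = tensor_tracer.keras_layer_checkpoint(dense, 'dense_1_out')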
class TensorTracer(object):
@@ -105,6 +176,34 @@ class TensorTracer(object):
match = _FLAG_NO_QUOTE_PAT.match(flags, pos)
return match
@staticmethod
def validate_flag_names():
"""Validates if the TensorTrace flags passed are valid."""
valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE,
_FLAG_NAME_EXCLUDED_OPNAMES,
_FLAG_NAME_EXCLUDED_OPTYPES,
_FLAG_NAME_INCLUDED_OPNAMES,
_FLAG_NAME_INCLUDED_OPTYPES,
_FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE,
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS,
_FLAG_NAME_OP_RANGE]
tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR)
if not tensor_tracer_flags:
return
pos = 0
while True:
match = TensorTracer._match_next_flag(tensor_tracer_flags, pos)
if not match:
break
flag_name = match.group(1)
if flag_name not in valid_flag_names:
raise ValueError(
'The flag name "%s" passed via the environment variable "%s" '
'is invalid. Valid flag names are:'
'\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names))
pos = match.end()
@staticmethod
def print_flag_values():
"""Prints all TensorTracer flags passed via environment variables."""
@ -146,6 +245,20 @@ class TensorTracer(object):
pos = match.end()
return ''
@staticmethod
def flag_value_to_re_list(flag_name):
"""Converts list of strings to compiled RE."""
re_list = []
flag_value = TensorTracer.get_flag_value(flag_name)
if not flag_value:
return re_list
list_of_values = flag_value.split()
for v in list_of_values:
r = re.compile(v)
re_list.append(r)
return re_list
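For instance, a flag value such as --included_opnames='dense.* conv[0-9]+/.*'
yields a list of compiled patterns that are later applied with re.match; a
standalone sketch with hypothetical patterns:

    import re

    flag_value = 'dense.* conv[0-9]+/.*'
    re_list = [re.compile(v) for v in flag_value.split()]
    assert any(r.match('dense_1/MatMul') for r in re_list)        # included
    assert not any(r.match('embedding/lookup') for r in re_list)  # not matched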
@staticmethod
def is_enabled():
"""Returns True if TensorTracer is enabled."""
@ -186,29 +299,67 @@ class TensorTracer(object):
"""Checks if the given trace mode is valid."""
    valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR,
                         _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM,
                         _TRACE_MODE_MAX_ABS]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor-Tracer. '
                       'Valid trace modes are: %s'%(trace_mode,
                                                    valid_trace_modes))
  @staticmethod
  def unsafe_op(op):
    """Returns True if this Op is not safe to be traced."""
    if control_flow_util.IsInCond(op):
      return True
    # Reasons for not including the following op types:
    #   Assign: causes incorrect results with CPU tracing.
    #   Others: compilation problems.
    if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']:
      return True
    return False
@staticmethod
def device_mismatch(device_type, op):
if device_type == _DEVICE_TYPE_TPU:
# pylint: disable=protected-access
return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr
# pylint: enable=protected-access
return False
@staticmethod
def unsafe_scalar_trace(op):
"""Return true if scalar output tensor from Op is not safe to be traced."""
# Tracing the following causes cycle in the graph on TPU.
if op.type in ['LoopCond', 'Enter', 'Merge', 'Const',
'Switch', 'Less', 'ReadVariableOp']:
return True
# Tracing the following will cause casting-issue
# with the norm tracing mode or other compilation issues on CPU.
if op.type in ['VarHandleOp', 'IteratorToStringHandle',
'IteratorGetNext', 'OneShotIterator',
'IteratorV2', 'MakeIterator',
'BatchDatasetV2', 'MapDataset',
'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']:
return True
return False
@staticmethod
def less_interesting_op(op):
"""Returns True if the given Op is not an interesting one to be traced."""
include_less_interesting = TensorTracer.get_flag_value(
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)
if include_less_interesting:
return False
return op.type in ['Const', 'Identity', 'Cast', 'Shape']
@staticmethod
def reason(op_idx, details):
"""Returns why the Op at op_idx is traced or not."""
"""Returns reason why the Op at op_idx is traced or not."""
return '%d %s'%(op_idx, details)
@staticmethod
@@ -274,6 +425,33 @@ class TensorTracer(object):
assert len(unsorted_ops) == len(sorted_ops)
return (True, sorted_ops)
@staticmethod
def _make_op_and_tensor_maps(op_list):
"""Creates various maps and lists from op_list.
Args:
op_list: a list of Ops
Returns:
opname_idx_map: a map from Op's name to its index in op_list.
tensor_list: a list of output tensors of the Ops in op_list.
tensorname_idx_map: a map from output tensor name to its index
in tensor_list.
"""
opname_idx_map = {}
tensor_list = []
tensorname_idx_map = {}
for op_id, op in enumerate(op_list):
if op.name in opname_idx_map:
raise ValueError('Duplicated Op name: %s'%op.name)
opname_idx_map[op.name] = op_id
for output_tensor in op.outputs:
if output_tensor.name not in tensorname_idx_map:
tensor_list.append(output_tensor)
tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
return (opname_idx_map, tensor_list, tensorname_idx_map)
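On a toy two-op graph, the resulting maps look as follows (a standalone sketch
mirroring the same logic):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
      a = tf.constant(1.0, name='a')
      b = tf.identity(a, name='b')
    ops = g.get_operations()
    opname_idx_map = {op.name: i for i, op in enumerate(ops)}
    tensor_list = [t for op in ops for t in op.outputs]
    tensorname_idx_map = {t.name: i for i, t in enumerate(tensor_list)}
    # opname_idx_map == {'a': 0, 'b': 1}
    # tensorname_idx_map == {'a:0': 0, 'b:0': 1}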
def __init__(self):
"""Initializes a TensorTracer.
@@ -281,16 +459,20 @@
"""
self._version = 'use-outside-compilation'
self._device_type = None
TensorTracer.validate_flag_names()
self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE)
if not self._trace_mode:
self._trace_mode = _TRACE_MODE_NAN_INF
TensorTracer.check_trace_mode(self._trace_mode)
self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
self._instrument_records = {}
self._set_trace_file_path()
self._set_report_file()
self._set_op_range()
self._set_excluded_opnames()
self._set_excluded_optypes()
self._set_included_opnames()
self._set_included_optypes()
self._num_replicas = None
self._replica_id = None
@@ -318,10 +500,7 @@
"""Sets the path of the output trace file."""
self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE)
if self._trace_file_path and TensorTracer.use_test_undeclared_outputs_dir():
if os.path.isabs(self._trace_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'trace_file_path cannot be an absolute path (%s)'
@@ -330,6 +509,22 @@
self._trace_file_path = os.path.join(outputs_dir,
self._trace_file_path)
def _set_report_file(self):
"""Sets the path of the output report file."""
self._report_file_path = TensorTracer.get_flag_value(_FLAG_NAME_REPORT_FILE)
if not self._report_file_path:
self._report_file = None
return
try:
self._report_file = gfile.Open(self._report_file_path, 'w')
except IOError as e:
raise e
def _close_report_file(self):
if self._report_file:
self._report_file.close()
def _set_op_range(self):
"""Sets the index range of the Ops that we will consider tracing."""
@@ -350,19 +545,48 @@
return False
return self._op_range[1] < 0 or idx <= self._op_range[1]
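For example, --op_range=10:50 restricts consideration to ops whose topological
index idx satisfies 10 <= idx <= 50; per the check above, a negative upper
bound means no upper limit. A sketch of the same parsing (values hypothetical):

    import re

    op_range_pat = re.compile(r'(\d+):(\d+)')  # same pattern as _OP_RANGE_PAT
    match = op_range_pat.match('10:50')        # e.g. --op_range=10:50
    op_range = (int(match.group(1)), int(match.group(2)))
    assert op_range == (10, 50)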
def _set_excluded_opnames(self):
self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list(
_FLAG_NAME_EXCLUDED_OPNAMES)
def _set_excluded_optypes(self):
self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list(
_FLAG_NAME_EXCLUDED_OPTYPES)
def _set_included_opnames(self):
self._included_opname_re_list = TensorTracer.flag_value_to_re_list(
_FLAG_NAME_INCLUDED_OPNAMES)
def _set_included_optypes(self):
self._included_optype_re_list = TensorTracer.flag_value_to_re_list(
_FLAG_NAME_INCLUDED_OPTYPES)
def _is_user_included_op(self, op):
for opname_re in self._included_opname_re_list:
if opname_re.match(op.name):
return True
for optype_re in self._included_optype_re_list:
if optype_re.match(op.type):
return True
return False
def _is_user_excluded_op(self, op):
for opname_re in self._excluded_opname_re_list:
if opname_re.match(op.name):
return True
for optype_re in self._excluded_optype_re_list:
if optype_re.match(op.type):
return True
return False
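These two predicates feed into _skip_op below: an explicit user include wins
over every other check, and an explicit exclude wins over the built-in
heuristics. A hypothetical helper mirroring that decision order:

    def skip_op(user_included, user_excluded, in_op_range, unsafe, boring):
      if user_included:   # an explicit include wins over everything
        return False
      if user_excluded:   # an explicit exclude wins over the heuristics
        return True
      return (not in_op_range) or unsafe or boring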
def _write_report(self, content):
"""Writes the given content to the report."""
line = '%s %s'%(_TRACER_LOG_PREFIX, content)
if self._report_file:
self._report_file.write(line)
else:
logging.info(line)
def _write_config_section(self):
"""Writes the config section of the report."""
@@ -382,15 +606,42 @@
self._write_report('"%s" %s\n'%(key, self._instrument_records[key]))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))
def _write_op_list_section(self, op_list, tensorname_idx_map):
"""Writes the Op-list section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list)))
for i in range(0, len(op_list)):
op = op_list[i]
line = '%d "%s" %s'%(i, op.name, op.type)
for out_tensor in op.outputs:
if out_tensor.name not in tensorname_idx_map:
raise ValueError(
'out_tensor %s is not in tensorname_idx_map'%out_tensor.name)
line += ' %d'%tensorname_idx_map[out_tensor.name]
line += '\n'
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))
def _write_tensor_list_section(self, tensor_list, opname_idx_map):
"""Writes the tensor-list section of the report."""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_SECTION_NAME_TENSOR_LIST))
self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list)))
for i in range(0, len(tensor_list)):
tensor = tensor_list[i]
line = '%d "%s"'%(i, tensor.name)
for consumer_op in tensor.consumers():
if consumer_op.name not in opname_idx_map:
raise ValueError(
'consumer_op %s is not in opname_idx_map'%consumer_op.name)
line += ' %d'%opname_idx_map[consumer_op.name]
line += '\n'
self._write_report(line)
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_SECTION_NAME_TENSOR_LIST))
def _write_graph_section(self, succeed, sorted_or_cycle):
"""Writes the graph section of the report."""
@@ -422,7 +673,7 @@
Args:
op_name: the name of the Op that outputs the tensor to be printed.
output_idx: which output of the Op it is (0 means the first output).
num_elements: number of elements to print (-1 means print all).
    tensor: the tensor to be returned.
    output_tensor: the tensor to be printed.
@@ -430,10 +681,13 @@
The same tensor passed via the "tensor" argument.
"""
msg = '"%s:%d" '%(op_name, output_idx)
if self._trace_file_path:
output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path
else:
output_stream = sys.stderr
print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor),
' @', self._replica_id,
'\n', output_tensor, '\n',
summarize=num_elements,
output_stream=output_stream)
with ops.control_dependencies([print_op]):
@@ -442,7 +696,8 @@
def _detect_nan_inf(tensor):
"""Trace function for detecting any NaN/Inf in the tensor."""
      if tensor.dtype.is_floating:
        if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
            dtypes.float16):
          # Since the host can't handle bf16, always convert the tensor to f32.
          tensor = math_ops.cast(tensor, dtypes.float32)
        output_tensor = math_ops.reduce_any(
@@ -450,12 +705,19 @@
gen_math_ops.is_inf(tensor)))
else:
output_tensor = constant_op.constant(0)
return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)
def _show_norm(tensor):
tensor = math_ops.cast(tensor, dtypes.float64)
output_tensor = linalg_ops.norm(tensor)
return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)
def _show_max_abs(tensor):
output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)),
dtypes.float64)
zero = constant_op.constant(0, dtypes.float64)
output_tensor = gen_math_ops.maximum(zero, output_tensor)
return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)
def _show_part_tensor(tensor):
"""Trace function for printing part of the tensor."""
@@ -468,23 +730,139 @@
return _print_tensor(op_name, output_idx, -1, tensor, tensor)
if self._trace_mode == _TRACE_MODE_NAN_INF:
return _detect_nan_inf
if self._trace_mode == _TRACE_MODE_PART_TENSOR:
return _show_part_tensor
if self._trace_mode == _TRACE_MODE_FULL_TENSOR:
return _show_full_tensor
if self._trace_mode == _TRACE_MODE_NORM:
return _show_norm
if self._trace_mode == _TRACE_MODE_MAX_ABS:
return _show_max_abs
raise RuntimeError('Tensor trace fun for %s is not yet implemented'
%self._trace_mode)
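Standalone, the two new trace modes reduce each traced tensor to a scalar
before printing, as sketched here with plain TF ops (the maximum with zero
presumably guards against the -inf identity of reduce_max on empty tensors):

    import tensorflow as tf

    t = tf.constant([[3.0, -4.0]])
    norm_val = tf.norm(tf.cast(t, tf.float64))            # -> 5.0
    max_abs = tf.maximum(
        tf.constant(0.0, tf.float64),
        tf.cast(tf.reduce_max(tf.abs(t)), tf.float64))    # -> 4.0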
def _skip_op(self, op_id, op, user_included, user_excluded):
"""Returns True if we should not trace Op."""
if user_included:
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_USER_INCLUDED)
return False
if user_excluded:
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_USER_EXCLUDED)
return True
if not self._inside_op_range(op_id):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_OUTSIDE_OP_RANGE)
return True
if TensorTracer.unsafe_op(op):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_UNSAFE_OP)
return True
if TensorTracer.device_mismatch(self._device_type, op):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_DEVICE_MISMATCH)
return True
if TensorTracer.less_interesting_op(op):
self._instrument_records[op.name] = TensorTracer.reason(
op_id, _REASON_LESS_INTERESTING_OP)
return True
return False
def _skip_tensor(self, op_id, out_tensor, user_included,
user_excluded):
"""Returns True if we should not trace out_tensor."""
# Skips a tensor if the tensor has a non-numeric type.
    # Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
    # because it also excludes tensors with dtypes bool and
    # float32_ref, which we actually want to trace.
non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
dtypes.string])
if out_tensor.dtype in non_numeric_tensor_types:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_NON_NUMERIC_TENSOR)
return True
if user_included:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_USER_INCLUDED)
return False
if user_excluded:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_USER_EXCLUDED)
return True
if not out_tensor.get_shape().is_fully_defined():
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_DYNAMIC_SHAPE)
return True
rank = len(out_tensor.shape)
if rank < 1:
# scalar
if TensorTracer.unsafe_scalar_trace(out_tensor.op):
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_UNSAFE_SCALAR)
return True
else:
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_SCALAR_GET_TRACED)
return False
else:
# tensor
self._instrument_records[out_tensor.name] = TensorTracer.reason(
op_id, _REASON_TENSOR_GET_TRACED)
return False
def _pre_tracing(self, graph):
"""Work needs to be done prior to TPU or CPU tracing."""
operations = graph.get_operations()
(opname_idx_map, tensor_list, tensorname_idx_map) = (
TensorTracer._make_op_and_tensor_maps(operations))
self._write_config_section()
self._write_op_list_section(operations, tensorname_idx_map)
self._write_tensor_list_section(tensor_list, opname_idx_map)
# Does the topological sort before adding any nodes to the graph.
(succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
return (operations, succeed, sorted_or_cycle)
def _post_tracing(self, succeed, sorted_or_cycle):
"""Work needs to be done after TPU or CPU tracing."""
self._write_reason_section()
self._write_graph_section(succeed, sorted_or_cycle)
self._close_report_file()
def _get_checkpoints(self, graph):
"""Returns the list of Ops that produce the tensors traced with API.
Args:
graph: the graph of Ops.
Returns:
A set of operation names which should be traced.
"""
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
_TENSOR_TRACER_CHECKPOINT))
checkpoint_operations = set()
tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION)
for (tensor, checkpoint_name) in tensor_tracer_variables:
self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
checkpoint_operations.add(tensor.op.name)
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_TENSOR_TRACER_CHECKPOINT))
return checkpoint_operations
def trace_tpu(self, graph, result_tensor, num_replicas=None):
"""Traces the tensors generated by TPU Ops in a TF graph.
Args:
graph: the graph of Ops executed on the TPU.
result_tensor: a result tensor of evaluating the graph.
num_replicas: number of replicas used on the TPU.
@@ -502,38 +880,22 @@
TensorTracer.check_device_type(self._device_type)
result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
result_tensor)
(operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
tracing_ops = []
checkpoint_operations = self._get_checkpoints(graph)
for op_id, op in enumerate(operations):
if checkpoint_operations and op.name not in checkpoint_operations:
continue
user_included = self._is_user_included_op(op)
user_excluded = self._is_user_excluded_op(op)
if self._skip_op(op_id, op, user_included, user_excluded):
continue
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
if self._skip_tensor(op_id, out_tensor, user_included,
user_excluded):
continue
consumers = out_tensor.consumers()
trace_op = tpu.outside_compilation(
self._make_tensor_trace_fun(op.name, i), out_tensor)
@@ -546,8 +908,45 @@
        # If there is no consumer, we will add the control dependency later,
        # when we add control dependencies to the output operations.
tracing_ops.append(trace_op)
self._post_tracing(succeed, sorted_or_cycle)
return (result_tensor_copy, tracing_ops)
def trace_cpu(self, graph):
"""Traces the tensors generated by CPU Ops in a TF graph.
Args:
graph: the graph of Ops executed on the CPU.
Returns:
tracing_calls: a map from keys to trace calls.
A key is constructed from an Op's name.
       A trace call consists of a function and a list of tensors
       (the function will be invoked with those tensors).
"""
self._device_type = _DEVICE_TYPE_CPU
TensorTracer.check_device_type(self._device_type)
self._num_replicas = 1
self._replica_id = 0
(operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
tracing_calls = {}
checkpoint_operations = self._get_checkpoints(graph)
for op_id, op in enumerate(operations):
if checkpoint_operations and op.name not in checkpoint_operations:
continue
user_included = self._is_user_included_op(op)
user_excluded = self._is_user_excluded_op(op)
if self._skip_op(op_id, op, user_included, user_excluded):
continue
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
if self._skip_tensor(op_id, out_tensor, user_included,
user_excluded):
continue
trace_fun = self._make_tensor_trace_fun(op.name, i)
trace_call = (trace_fun, [out_tensor])
trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i)
tracing_calls[trace_call_key] = trace_call
self._post_tracing(succeed, sorted_or_cycle)
return tracing_calls
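
Outside of TPUEstimator (which wires this up automatically, as the second file
below shows), the CPU path could be driven manually along these lines (a
sketch; the graph is hypothetical):

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tensor_tracer

    with tf.Graph().as_default() as g:
      x = tf.constant([[1.0, -2.0]])
      y = tf.abs(x, name='abs_out')
      tt = tensor_tracer.TensorTracer()
      tracing_calls = tt.trace_cpu(g)
      # Each trace call is (trace_fun, [tensor]); materialize them as ops.
      trace_ops = [fun(*tensors) for fun, tensors in tracing_calls.values()]
      with tf.Session() as sess:
        sess.run([y] + trace_ops)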


@@ -336,6 +336,16 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
hooks = None
if self.host_call is not None:
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
if tensor_tracer.TensorTracer.is_enabled():
tt = tensor_tracer.TensorTracer()
tracing_calls = tt.trace_cpu(ops.get_default_graph())
tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls)
tracing_functions = tracing_call_ret.values()
if tracing_functions:
if hooks:
hooks.extend([_OutfeedHostCallHook(tracing_functions)])
else:
hooks = [_OutfeedHostCallHook(tracing_functions)]
hooks = tuple(hooks or [])
scaffold = self.scaffold_fn() if self.scaffold_fn else None
return model_fn_lib.EstimatorSpec(