Tensor tracer refactor.
PiperOrigin-RevId: 248544325
This commit is contained in:
parent
f0836d2a3b
commit
ef8dc7d153
@ -154,6 +154,7 @@ py_library(
|
||||
"device_assignment.py",
|
||||
"session_support.py",
|
||||
"tensor_tracer.py",
|
||||
"tensor_tracer_flags.py",
|
||||
"topology.py",
|
||||
"tpu.py",
|
||||
"tpu_feed.py",
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import os.path
|
||||
import re
|
||||
import sys
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -40,22 +39,14 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import tensor_tracer_flags
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
|
||||
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
|
||||
_DEVICE_TYPE_TPU = 'tpu'
|
||||
_DEVICE_TYPE_CPU = 'cpu'
|
||||
_TRACE_MODE_NAN_INF = 'nan-inf'
|
||||
_TRACE_MODE_PART_TENSOR = 'part-tensor'
|
||||
_TRACE_MODE_PART_TENSOR_SIZE = 3
|
||||
_TRACE_MODE_FULL_TENSOR = 'full-tensor'
|
||||
_TRACE_MODE_FULL_IF_NAN = 'trace-back-if-nan'
|
||||
_FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size'
|
||||
_TRACE_MODE_NORM = 'norm'
|
||||
_TRACE_MODE_MAX_ABS = 'max-abs'
|
||||
_SUBMODE_BRIEF = 'brief'
|
||||
_SUBMODE_DETAILED = 'detailed'
|
||||
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
|
||||
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
|
||||
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
|
||||
@ -90,34 +81,7 @@ _FIELD_NAME_NUM_OPS = 'number-of-ops:'
|
||||
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
|
||||
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
|
||||
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
|
||||
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
|
||||
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
|
||||
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
|
||||
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
|
||||
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*')
|
||||
_FLAG_NAME_ENABLE = 'enable'
|
||||
_FLAG_NAME_TRACE_MODE = 'trace_mode'
|
||||
_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace'
|
||||
_FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
|
||||
_FLAG_NAME_TRACE_BEFORE_OPS = 'trace_before_included_ops'
|
||||
_FLAG_NAME_TRACE_AFTER_OPS = 'trace_after_included_ops'
|
||||
_FLAG_NAME_SUBMODE = 'submode'
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
|
||||
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
|
||||
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
|
||||
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
|
||||
_FLAG_NAME_INCLUDED_CORES = 'included_cores'
|
||||
_FLAG_NAME_TRACE_DIR = 'trace_dir'
|
||||
_FLAG_NAME_REPORT_FILE = 'report_file'
|
||||
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
|
||||
_FLAG_NAME_OP_RANGE = 'op_range'
|
||||
# Folder to dump the pre (before tensor tracer updates) and post graphs (after
|
||||
# tensor tracer updates).
|
||||
_FLAG_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
|
||||
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
|
||||
_OUTPUT_STREAM_ESCAPE = 'file://'
|
||||
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
|
||||
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
|
||||
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
|
||||
_TRACE_FILE_NAME = 'trace.all'
|
||||
@ -227,7 +191,6 @@ def _create_tensor_values_cache(graph, num_tensors):
|
||||
use_resource=True,
|
||||
collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
|
||||
|
||||
|
||||
class TensorTracer(object):
|
||||
"""A software construct for tracing tensor values in a TF graph on TPU.
|
||||
|
||||
@ -247,233 +210,10 @@ class TensorTracer(object):
|
||||
# The set of graphs that are rewritten by tensor tracer.
|
||||
_traced_graphs = set()
|
||||
|
||||
@staticmethod
|
||||
def _match_next_flag(flags, pos):
|
||||
"""Returns the match for the next TensorTracer flag.
|
||||
|
||||
Args:
|
||||
flags: a string that contains the flags.
|
||||
pos: where in flags to start the search.
|
||||
|
||||
Returns:
|
||||
A pair where the first element is the regular-expression
|
||||
match found and the second element indicates if the match
|
||||
has a value.
|
||||
"""
|
||||
|
||||
match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos)
|
||||
if match:
|
||||
return match, True
|
||||
match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos)
|
||||
if match:
|
||||
return match, True
|
||||
match = _FLAG_NO_QUOTE_PAT.match(flags, pos)
|
||||
if match:
|
||||
return match, True
|
||||
match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
|
||||
if match:
|
||||
# The flag is found but is not given a value.
|
||||
return match, False
|
||||
# The flag is not found.
|
||||
return None, False
|
||||
|
||||
@staticmethod
|
||||
def validate_flag_names():
|
||||
"""Validates if the TensorTrace flags passed are valid."""
|
||||
valid_flag_names = [
|
||||
_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, _FLAG_NAME_USE_COMPACT_TRACE,
|
||||
_FLAG_NAME_TRACE_SCALAR_OPS, _FLAG_NAME_TRACE_BEFORE_OPS,
|
||||
_FLAG_NAME_TRACE_AFTER_OPS, _FLAG_NAME_TRACE_STACK_SIZE,
|
||||
_FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES,
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES,
|
||||
_FLAG_NAME_INCLUDED_OPTYPES, _FLAG_NAME_TRACE_DIR,
|
||||
_FLAG_NAME_INCLUDED_CORES, _FLAG_NAME_REPORT_FILE,
|
||||
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, _FLAG_NAME_OP_RANGE,
|
||||
_FLAG_DUMP_BEFORE_AFTER_GRAPHS
|
||||
]
|
||||
tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR)
|
||||
if not tensor_tracer_flags:
|
||||
return
|
||||
pos = 0
|
||||
while True:
|
||||
match, _ = TensorTracer._match_next_flag(tensor_tracer_flags, pos)
|
||||
if not match:
|
||||
break
|
||||
flag_name = match.group(1)
|
||||
if flag_name not in valid_flag_names:
|
||||
raise ValueError(
|
||||
'The flag name "%s" passed via the environment variable "%s" '
|
||||
'is invalid. Valid flag names are:'
|
||||
'\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names))
|
||||
pos = match.end()
|
||||
|
||||
@staticmethod
|
||||
def print_flag_values():
|
||||
"""Prints all TensorTracer flags passed via environment variables."""
|
||||
|
||||
tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR)
|
||||
if not tensor_tracer_flags:
|
||||
return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR
|
||||
result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR,
|
||||
tensor_tracer_flags)
|
||||
result += 'Individual flag value:\n'
|
||||
pos = 0
|
||||
while True:
|
||||
match, has_value = TensorTracer._match_next_flag(
|
||||
tensor_tracer_flags, pos)
|
||||
if not match:
|
||||
break
|
||||
flag_name = match.group(1)
|
||||
if has_value:
|
||||
flag_value = match.group(2)
|
||||
else:
|
||||
flag_value = None
|
||||
result += ' %s: %s\n'%(flag_name, flag_value)
|
||||
pos = match.end()
|
||||
result += '\n'
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def flag_value_as_int_list(wanted_flag_name):
|
||||
"""Returns the integer list of a TensorTracer flag.
|
||||
|
||||
Args:
|
||||
wanted_flag_name: the name of the flag we are looking for.
|
||||
|
||||
Returns:
|
||||
the value of the flag.
|
||||
Raises:
|
||||
RuntimeError: If supposedly deadcode is reached.
|
||||
"""
|
||||
int_list = []
|
||||
found, flag_value = TensorTracer.get_flag_value(wanted_flag_name)
|
||||
|
||||
if found:
|
||||
try:
|
||||
integer_values = flag_value.split(',')
|
||||
int_list = [int(int_val) for int_val in integer_values]
|
||||
except ValueError:
|
||||
logging.warning('Cannot convert %s to int for flag %s', int_list,
|
||||
wanted_flag_name)
|
||||
return int_list
|
||||
|
||||
@staticmethod
|
||||
def get_flag_int_value(wanted_flag_name, default_value):
|
||||
"""Returns the int value of a TensorTracer flag.
|
||||
|
||||
Args:
|
||||
wanted_flag_name: the name of the flag we are looking for.
|
||||
default_value: the default value for the flag, if not provided.
|
||||
Returns:
|
||||
the value of the flag.
|
||||
Raises:
|
||||
RuntimeError: If supposedly deadcode is reached.
|
||||
"""
|
||||
flag_int_value = default_value
|
||||
found, flag_value = TensorTracer.get_flag_value(wanted_flag_name)
|
||||
|
||||
if found:
|
||||
try:
|
||||
flag_int_value = int(flag_value)
|
||||
except ValueError:
|
||||
logging.warning('Cannot convert %s to int for flag %s' % (
|
||||
flag_int_value, wanted_flag_name))
|
||||
return flag_int_value
|
||||
|
||||
@staticmethod
|
||||
def get_flag_value(wanted_flag_name):
|
||||
"""Returns the value of a TensorTracer flags.
|
||||
|
||||
Args:
|
||||
wanted_flag_name: the name of the flag we are looking for.
|
||||
|
||||
Returns:
|
||||
A pair where the first element indicates if the flag is
|
||||
found and the second element is the value of the flag.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If supposedly deadcode is reached.
|
||||
"""
|
||||
|
||||
tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR)
|
||||
if not tensor_tracer_flags:
|
||||
return False, None
|
||||
pos = 0
|
||||
while True:
|
||||
match, has_value = TensorTracer._match_next_flag(
|
||||
tensor_tracer_flags, pos)
|
||||
if not match:
|
||||
return False, None
|
||||
flag_name = match.group(1)
|
||||
if has_value:
|
||||
flag_value = match.group(2)
|
||||
else:
|
||||
flag_value = None
|
||||
if flag_name == wanted_flag_name:
|
||||
return True, flag_value
|
||||
pos = match.end()
|
||||
raise RuntimeError('Should not reach here.')
|
||||
|
||||
@staticmethod
|
||||
def flag_value_to_re_list(flag_name):
|
||||
"""Converts list of strings to compiled RE."""
|
||||
|
||||
re_list = []
|
||||
found, flag_value = TensorTracer.get_flag_value(flag_name)
|
||||
if not found or not flag_value:
|
||||
return re_list
|
||||
list_of_values = flag_value.split()
|
||||
for v in list_of_values:
|
||||
r = re.compile(v)
|
||||
re_list.append(r)
|
||||
return re_list
|
||||
|
||||
@staticmethod
|
||||
def _is_flag_on(flag_name):
|
||||
"""Returns True if the given flag is on."""
|
||||
|
||||
found, flag_value = TensorTracer.get_flag_value(flag_name)
|
||||
if not found:
|
||||
return False
|
||||
if flag_value is None:
|
||||
return True
|
||||
# Depends on the flag value.
|
||||
flag_value = flag_value.lower()
|
||||
enabled = flag_value in ['1', 't', 'true', 'y', 'yes']
|
||||
return enabled
|
||||
|
||||
@staticmethod
|
||||
def is_enabled():
|
||||
"""Returns True if TensorTracer is enabled."""
|
||||
|
||||
if TensorTracer._is_flag_on(_FLAG_NAME_ENABLE):
|
||||
logging.info('Tensor Tracer is enabled with flags %s.' %
|
||||
os.getenv(_FLAGS_ENV_VAR))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def use_test_undeclared_outputs_dir():
|
||||
"""Decides the output directory of the report and trace files.
|
||||
|
||||
Args:
|
||||
None.
|
||||
|
||||
Returns:
|
||||
True if the output files should be written to the
|
||||
test-undeclared-outputs-directory defined via an
|
||||
env variable.
|
||||
"""
|
||||
|
||||
return TensorTracer._is_flag_on(
|
||||
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)
|
||||
|
||||
@staticmethod
|
||||
def use_compact_trace():
|
||||
return TensorTracer._is_flag_on(
|
||||
_FLAG_NAME_USE_COMPACT_TRACE)
|
||||
return tensor_tracer_flags.TTParameters().is_enabled()
|
||||
|
||||
@staticmethod
|
||||
def check_device_type(device_type):
|
||||
@ -482,31 +222,6 @@ class TensorTracer(object):
|
||||
if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]:
|
||||
raise ValueError('Invalid device_type "%s"'%device_type)
|
||||
|
||||
@staticmethod
|
||||
def check_trace_mode(trace_mode):
|
||||
"""Checks if the given trace mode is valid."""
|
||||
|
||||
valid_trace_modes = [
|
||||
_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, _TRACE_MODE_FULL_TENSOR,
|
||||
_TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS, _TRACE_MODE_FULL_IF_NAN
|
||||
]
|
||||
if trace_mode not in valid_trace_modes:
|
||||
raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.'
|
||||
'Valid trace modes are: %s'%(trace_mode,
|
||||
valid_trace_modes))
|
||||
|
||||
@staticmethod
|
||||
def check_submode(submode):
|
||||
"""Checks if the given submode is valid."""
|
||||
|
||||
if not submode:
|
||||
return
|
||||
valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
|
||||
if submode not in valid_submodes:
|
||||
raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.'
|
||||
'Valid submodes are: %s'%(submode,
|
||||
valid_submodes))
|
||||
|
||||
@staticmethod
|
||||
def loop_cond_op(op):
|
||||
return op.type in ('LoopCond', 'RefLoopCond')
|
||||
@ -569,14 +284,10 @@ class TensorTracer(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def less_interesting_op(op):
|
||||
"""Returns True if the given Op is not an interesting one to be traced."""
|
||||
|
||||
found, _ = TensorTracer.get_flag_value(
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)
|
||||
if found:
|
||||
# users force to include all ops.
|
||||
def _less_interesting_op(self, op):
|
||||
"""Returns True if the given op is not an interesting one to be traced."""
|
||||
# If flag is set to include less interesting ops, then include everything.
|
||||
if self._parameters.include_less_interesting_ops:
|
||||
return False
|
||||
# Following ops are highly unlikey to cause bugs.
|
||||
return op.type in ['Const', 'Identity', 'Cast', 'Shape']
|
||||
@ -680,67 +391,22 @@ class TensorTracer(object):
|
||||
tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
|
||||
return (opname_idx_map, tensor_list, tensorname_idx_map)
|
||||
|
||||
@staticmethod
|
||||
def is_conditional_trace_mode(trace_mode):
|
||||
return trace_mode == _TRACE_MODE_FULL_IF_NAN
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a TensorTracer.
|
||||
|
||||
Sets the various member fields from the flags (if given) or the defaults.
|
||||
"""
|
||||
self._parameters = tensor_tracer_flags.TTParameters()
|
||||
self._set_report_file()
|
||||
self._version = 'use-outside-compilation'
|
||||
self._device_type = None
|
||||
TensorTracer.validate_flag_names()
|
||||
found, self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE)
|
||||
if not found or not self._trace_mode:
|
||||
self._trace_mode = _TRACE_MODE_NAN_INF
|
||||
TensorTracer.check_trace_mode(self._trace_mode)
|
||||
found, self._submode = TensorTracer.get_flag_value(_FLAG_NAME_SUBMODE)
|
||||
if not found or not self._submode:
|
||||
self._submode = _SUBMODE_DETAILED
|
||||
TensorTracer.check_submode(self._submode)
|
||||
self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
|
||||
self._instrument_records = {}
|
||||
self._set_trace_dir()
|
||||
self._set_report_file()
|
||||
self._set_op_range()
|
||||
self._set_excluded_opnames()
|
||||
self._set_excluded_optypes()
|
||||
self._set_included_opnames()
|
||||
self._set_included_optypes()
|
||||
self._num_replicas = None
|
||||
self._num_replicas_per_host = None
|
||||
self._num_hosts = None
|
||||
self._replica_id = None
|
||||
self._included_op_full_names = set()
|
||||
self._is_conditional_trace = TensorTracer.is_conditional_trace_mode(
|
||||
self._trace_mode)
|
||||
self._trace_scalar_ops = TensorTracer._is_flag_on(
|
||||
_FLAG_NAME_TRACE_SCALAR_OPS)
|
||||
|
||||
# _trace_ops_before_included and _trace_ops_after_included denotes to depth
|
||||
# of tracing relative to the ops given in --included_opnames or
|
||||
# --included_optypes
|
||||
# For example, in the below graph
|
||||
# op1 --> op2 --> op3 --> op4 --> op5
|
||||
# If --included_opnames=op3 then only op3 will be traced.
|
||||
# If also --trace_before_included_ops=2 (_trace_ops_before_included), then
|
||||
# op1 and op2 will be traced as they are at most 2 hops apart from an
|
||||
# included op. Similarly, if --trace_after_included_ops=2, then op4 and op5
|
||||
# will also be traced.
|
||||
self._trace_ops_before_included = TensorTracer.get_flag_int_value(
|
||||
_FLAG_NAME_TRACE_BEFORE_OPS, 0)
|
||||
self._trace_ops_after_included = TensorTracer.get_flag_int_value(
|
||||
_FLAG_NAME_TRACE_AFTER_OPS, 0)
|
||||
self._trace_stack_size = TensorTracer.get_flag_int_value(
|
||||
_FLAG_NAME_TRACE_STACK_SIZE, 1)
|
||||
|
||||
_, self._graph_dump_path = TensorTracer.get_flag_value(
|
||||
_FLAG_DUMP_BEFORE_AFTER_GRAPHS)
|
||||
|
||||
self._included_cores = TensorTracer.flag_value_as_int_list(
|
||||
_FLAG_NAME_INCLUDED_CORES)
|
||||
|
||||
def _add_replica_id_to_graph(self):
|
||||
"""Adds nodes for computing the replica ID to the graph."""
|
||||
@ -754,35 +420,13 @@ class TensorTracer(object):
|
||||
else:
|
||||
self._replica_id = 'unknown'
|
||||
|
||||
def _set_trace_dir(self):
|
||||
found, self._trace_dir = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_DIR)
|
||||
if found and self._trace_dir \
|
||||
and TensorTracer.use_test_undeclared_outputs_dir():
|
||||
raise ValueError('Cannot not use --%s and --%s at the same time'
|
||||
%(_FLAG_NAME_TRACE_DIR,
|
||||
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
|
||||
if TensorTracer.use_test_undeclared_outputs_dir():
|
||||
self._trace_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
|
||||
|
||||
def _set_report_file(self):
|
||||
"""Sets the path of the output report file."""
|
||||
|
||||
found, self._report_file_path = TensorTracer.get_flag_value(
|
||||
_FLAG_NAME_REPORT_FILE)
|
||||
if found and self._report_file_path \
|
||||
and TensorTracer.use_test_undeclared_outputs_dir():
|
||||
if os.path.isabs(self._report_file_path):
|
||||
raise ValueError('If use_test_undeclared_outputs_dir is set,'
|
||||
'report_file_path cannot be an absolute path (%s)'
|
||||
%self._report_file_path)
|
||||
outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
|
||||
self._report_file_path = os.path.join(outputs_dir,
|
||||
self._report_file_path)
|
||||
if not self._report_file_path:
|
||||
if not self._parameters.report_file_path:
|
||||
self._report_file = None
|
||||
return
|
||||
try:
|
||||
self._report_file = gfile.Open(self._report_file_path, 'w')
|
||||
self._report_file = gfile.Open(self._parameters.report_file_path, 'w')
|
||||
except IOError as e:
|
||||
raise e
|
||||
|
||||
@ -790,41 +434,13 @@ class TensorTracer(object):
|
||||
if self._report_file:
|
||||
self._report_file.close()
|
||||
|
||||
def _set_op_range(self):
|
||||
"""Sets the index range of the Ops that we will consider tracing."""
|
||||
|
||||
found, op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE)
|
||||
if not found or not op_range:
|
||||
self._op_range = (-1, -1) # this means including all ops.
|
||||
return
|
||||
match = _OP_RANGE_PAT.match(op_range)
|
||||
if not match:
|
||||
self._op_range = (-1, -1) # this means including all ops.
|
||||
return
|
||||
self._op_range = (int(match.group(1)), int(match.group(2)))
|
||||
|
||||
def _inside_op_range(self, idx):
|
||||
"""Return True if the given index is inside the selected range."""
|
||||
|
||||
if idx < self._op_range[0]:
|
||||
if idx < self._parameters.op_range[0]:
|
||||
return False
|
||||
return self._op_range[1] < 0 or idx <= self._op_range[1]
|
||||
|
||||
def _set_excluded_opnames(self):
|
||||
self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_EXCLUDED_OPNAMES)
|
||||
|
||||
def _set_excluded_optypes(self):
|
||||
self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES)
|
||||
|
||||
def _set_included_opnames(self):
|
||||
self._included_opname_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_INCLUDED_OPNAMES)
|
||||
|
||||
def _set_included_optypes(self):
|
||||
self._included_optype_re_list = TensorTracer.flag_value_to_re_list(
|
||||
_FLAG_NAME_INCLUDED_OPTYPES)
|
||||
return (self._parameters.op_range[1] < 0 or
|
||||
idx <= self._parameters.op_range[1])
|
||||
|
||||
def _is_user_included_op(self, op):
|
||||
"""Checks whether the op is included in the tensor tracer flags.
|
||||
@ -844,12 +460,12 @@ class TensorTracer(object):
|
||||
"""Helper function to check if op is included or not."""
|
||||
if op.name in self._included_op_full_names:
|
||||
return True
|
||||
for opname_re in self._included_opname_re_list:
|
||||
for opname_re in self._parameters.included_opname_re_list:
|
||||
if opname_re.match(op.name):
|
||||
self._included_op_full_names.add(op.name)
|
||||
return True
|
||||
|
||||
for optype_re in self._included_optype_re_list:
|
||||
for optype_re in self._parameters.included_optype_re_list:
|
||||
if optype_re.match(op.type):
|
||||
self._included_op_full_names.add(op.name)
|
||||
return True
|
||||
@ -870,15 +486,15 @@ class TensorTracer(object):
|
||||
return False
|
||||
# check_after and check_before are swapped below, as below operation
|
||||
# checks the distance from an arbitrary op to included ops.
|
||||
return _is_op_or_any_neighbor_included(op,
|
||||
self._trace_ops_after_included,
|
||||
self._trace_ops_before_included)
|
||||
return _is_op_or_any_neighbor_included(
|
||||
op, self._parameters.trace_ops_after_included,
|
||||
self._parameters.trace_ops_before_included)
|
||||
|
||||
def _is_user_excluded_op(self, op):
|
||||
for opname_re in self._excluded_opname_re_list:
|
||||
for opname_re in self._parameters.excluded_opname_re_list:
|
||||
if opname_re.match(op.name):
|
||||
return True
|
||||
for optype_re in self._excluded_optype_re_list:
|
||||
for optype_re in self._parameters.excluded_optype_re_list:
|
||||
if optype_re.match(op.type):
|
||||
return True
|
||||
return False
|
||||
@ -886,14 +502,15 @@ class TensorTracer(object):
|
||||
def _use_tensor_values_cache(self):
|
||||
"""Returns True if immediate tensors should be first saved to a cache."""
|
||||
|
||||
if self._trace_mode not in set([_TRACE_MODE_NAN_INF,
|
||||
_TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS]):
|
||||
if self._parameters.trace_mode not in set([
|
||||
tensor_tracer_flags.TRACE_MODE_NAN_INF,
|
||||
tensor_tracer_flags.TRACE_MODE_NORM,
|
||||
tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
|
||||
return False
|
||||
if self._trace_dir and _trace_files_need_precreated(self._trace_dir):
|
||||
if (self._parameters.trace_dir and
|
||||
_trace_files_need_precreated(self._parameters.trace_dir)):
|
||||
return True
|
||||
if TensorTracer.use_compact_trace():
|
||||
return True
|
||||
return False
|
||||
return self._parameters.use_compact_trace
|
||||
|
||||
def _save_tensor_value_to_cache_op(self, graph, cache_idx, updates):
|
||||
"""Returns an Op that will save the given updates to an entry in the cache."""
|
||||
@ -917,11 +534,13 @@ class TensorTracer(object):
|
||||
self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode))
|
||||
if self._included_cores:
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
|
||||
self._parameters.trace_mode))
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
|
||||
self._parameters.submode))
|
||||
if self._parameters.included_cores:
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
|
||||
len(self._included_cores)))
|
||||
len(self._parameters.included_cores)))
|
||||
else:
|
||||
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
|
||||
self._num_replicas))
|
||||
@ -1061,20 +680,24 @@ class TensorTracer(object):
|
||||
is_nan_producer = math_ops.reduce_any(is_nan_producer > 0)
|
||||
return is_nan_producer
|
||||
|
||||
if self._trace_mode == _TRACE_MODE_FULL_IF_NAN:
|
||||
if (self._parameters.trace_mode ==
|
||||
tensor_tracer_flags.TRACE_MODE_FULL_IF_NAN):
|
||||
return _detect_inf_nan_producer(tensor)
|
||||
if self._trace_mode == _TRACE_MODE_NAN_INF:
|
||||
if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
|
||||
return _detect_nan_inf(tensor)
|
||||
if self._trace_mode == _TRACE_MODE_PART_TENSOR:
|
||||
if (self._parameters.trace_mode ==
|
||||
tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
|
||||
return tensor
|
||||
if self._trace_mode == _TRACE_MODE_FULL_TENSOR:
|
||||
if (self._parameters.trace_mode ==
|
||||
tensor_tracer_flags.TRACE_MODE_FULL_TENSOR):
|
||||
return tensor
|
||||
if self._trace_mode == _TRACE_MODE_NORM:
|
||||
if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
|
||||
return _show_norm(tensor)
|
||||
if self._trace_mode == _TRACE_MODE_MAX_ABS:
|
||||
if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
|
||||
return _show_max_abs(tensor)
|
||||
raise RuntimeError(
|
||||
'Tensor trace fun for %s is not yet implemented' % self._trace_mode)
|
||||
'Tensor trace fun for %s is not yet implemented'
|
||||
% self._parameters.trace_mode)
|
||||
|
||||
def _make_tensor_trace_fun(self, tensor_name):
|
||||
"""Makes the tensor tracing function called by outside compilation.
|
||||
@ -1106,7 +729,7 @@ class TensorTracer(object):
|
||||
self._tensorname_idx_map.
|
||||
"""
|
||||
|
||||
if self._submode == _SUBMODE_BRIEF:
|
||||
if self._parameters.is_brief_mode():
|
||||
if tensor_name not in self._tensorname_idx_map:
|
||||
raise ValueError(
|
||||
'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
|
||||
@ -1114,8 +737,8 @@ class TensorTracer(object):
|
||||
else:
|
||||
msg = '"%s"'%tensor_name
|
||||
|
||||
if self._trace_dir:
|
||||
output_path = os.path.join(self._trace_dir, _TRACE_FILE_NAME)
|
||||
if self._parameters.trace_dir:
|
||||
output_path = os.path.join(self._parameters.trace_dir, _TRACE_FILE_NAME)
|
||||
output_stream = _OUTPUT_STREAM_ESCAPE + output_path
|
||||
else:
|
||||
output_stream = sys.stderr
|
||||
@ -1159,27 +782,31 @@ class TensorTracer(object):
|
||||
visited_tensors[input_tensor] = distance + 1
|
||||
return visitor_queue
|
||||
|
||||
tensors_to_print = _get_distance_k_tensors(self._trace_stack_size)
|
||||
tensors_to_print = _get_distance_k_tensors(
|
||||
self._parameters.trace_stack_size)
|
||||
print_ops = [_print_tensor(t.name, -1, t, t) for t in tensors_to_print]
|
||||
with ops.control_dependencies(print_ops):
|
||||
return constant_op.constant(True)
|
||||
|
||||
if self._trace_mode == _TRACE_MODE_FULL_IF_NAN:
|
||||
if (self._parameters.trace_mode ==
|
||||
tensor_tracer_flags.TRACE_MODE_FULL_IF_NAN):
|
||||
return _show_full_tensors
|
||||
if self._trace_mode == _TRACE_MODE_PART_TENSOR:
|
||||
if (self._parameters.trace_mode ==
|
||||
tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
|
||||
return _show_part_tensor
|
||||
# The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF,
|
||||
# _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are
|
||||
# The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
|
||||
# TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
|
||||
# performed within TPUs and only their results are transferred to CPU.
|
||||
# Simply, print the full tensor for these trace modes.
|
||||
if self._trace_mode in [
|
||||
_TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR,
|
||||
_TRACE_MODE_MAX_ABS
|
||||
]:
|
||||
if self._parameters.trace_mode in [
|
||||
tensor_tracer_flags.TRACE_MODE_NAN_INF,
|
||||
tensor_tracer_flags.TRACE_MODE_NORM,
|
||||
tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
|
||||
tensor_tracer_flags.TRACE_MODE_MAX_ABS]:
|
||||
return _show_full_tensor
|
||||
|
||||
raise RuntimeError('Tensor trace fun for %s is not yet implemented'
|
||||
%self._trace_mode)
|
||||
%self._parameters.trace_mode)
|
||||
|
||||
def _skip_op(self, op_id, op, user_included, user_excluded,
|
||||
in_exec_path=True):
|
||||
@ -1206,7 +833,7 @@ class TensorTracer(object):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_OUTSIDE_OP_RANGE)
|
||||
return True
|
||||
if TensorTracer.less_interesting_op(op):
|
||||
if self._less_interesting_op(op):
|
||||
self._instrument_records[op.name] = TensorTracer.reason(
|
||||
op_id, _REASON_LESS_INTERESTING_OP)
|
||||
return True
|
||||
@ -1251,8 +878,10 @@ class TensorTracer(object):
|
||||
if not out_tensor.get_shape().is_fully_defined():
|
||||
# If trace mode is nan-inf, norm or max, then the tensor will be reduced
|
||||
# to a scalar before the outside compilation call.
|
||||
if self._trace_mode in [
|
||||
_TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS
|
||||
if self._parameters.trace_mode in [
|
||||
tensor_tracer_flags.TRACE_MODE_NAN_INF,
|
||||
tensor_tracer_flags.TRACE_MODE_NORM,
|
||||
tensor_tracer_flags.TRACE_MODE_MAX_ABS
|
||||
]:
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_TENSOR_GET_TRACED)
|
||||
@ -1264,7 +893,7 @@ class TensorTracer(object):
|
||||
rank = len(out_tensor.shape)
|
||||
if rank < 1:
|
||||
# scalar
|
||||
if self._trace_scalar_ops:
|
||||
if self._parameters.trace_scalar_ops:
|
||||
if TensorTracer.unsafe_scalar_trace(out_tensor.op):
|
||||
self._instrument_records[out_tensor.name] = TensorTracer.reason(
|
||||
op_id, _REASON_UNSAFE_SCALAR)
|
||||
@ -1362,23 +991,23 @@ class TensorTracer(object):
|
||||
def _check_trace_files(self):
|
||||
"""Checks if any requirements for trace files are satisfied."""
|
||||
|
||||
if not self._trace_dir:
|
||||
if not self._parameters.trace_dir:
|
||||
# traces will be written to stderr. No need to check trace files.
|
||||
return
|
||||
if _trace_files_need_precreated(self._trace_dir):
|
||||
if _trace_files_need_precreated(self._parameters.trace_dir):
|
||||
for replica_id in range(0, self._num_replicas):
|
||||
trace_file_path = os.path.join(
|
||||
self._trace_dir,
|
||||
self._parameters.trace_dir,
|
||||
_COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id
|
||||
if not gfile.Exists(trace_file_path):
|
||||
raise RuntimeError(
|
||||
'%s must be pre-created with the '
|
||||
'appropriate properties.'%trace_file_path)
|
||||
else:
|
||||
if not gfile.Exists(self._trace_dir):
|
||||
gfile.MkDir(self._trace_dir)
|
||||
if not gfile.Exists(self._trace_dir):
|
||||
raise RuntimeError('Failed to create %s'%self._trace_dir)
|
||||
if not gfile.Exists(self._parameters.trace_dir):
|
||||
gfile.MkDir(self._parameters.trace_dir)
|
||||
if not gfile.Exists(self._parameters.trace_dir):
|
||||
raise RuntimeError('Failed to create %s'%self._parameters.trace_dir)
|
||||
|
||||
def _pre_tracing(self, graph, fetches):
|
||||
"""Work needs to be done prior to TPU or CPU tracing."""
|
||||
@ -1455,8 +1084,8 @@ class TensorTracer(object):
|
||||
replica_id_str = replica_id
|
||||
else:
|
||||
replica_id_str = '%d'%replica_id
|
||||
if self._trace_dir:
|
||||
output_path = os.path.join(self._trace_dir,
|
||||
if self._parameters.trace_dir:
|
||||
output_path = os.path.join(self._parameters.trace_dir,
|
||||
_COMPACT_TRACE_FILE_PREFIX) \
|
||||
+ replica_id_str
|
||||
output_stream = _OUTPUT_STREAM_ESCAPE + output_path
|
||||
@ -1708,12 +1337,12 @@ class TensorTracer(object):
|
||||
predicate_tensor, lambda: trace_fn(out_tensor, out_tensor_name),
|
||||
lambda: constant_op.constant(False)).op
|
||||
|
||||
if self._is_conditional_trace:
|
||||
if self._parameters.is_conditional_trace:
|
||||
trace_op = conditional_trace_fn(processed_out_tensor, out_tensor,
|
||||
tpu_wrap_trace_fn, tensor_name)
|
||||
elif self._included_cores:
|
||||
elif self._parameters.included_cores:
|
||||
should_print = constant_op.constant(False)
|
||||
for core in self._included_cores:
|
||||
for core in self._parameters.included_cores:
|
||||
should_print = gen_math_ops.logical_or(
|
||||
should_print, gen_math_ops.equal(self._replica_id, core))
|
||||
trace_op = conditional_trace_fn(should_print, processed_out_tensor,
|
||||
@ -1800,15 +1429,15 @@ class TensorTracer(object):
|
||||
# Checks for the assumption in _generate_flush_cache_op().
|
||||
raise RuntimeError('num_replicas_per_host (%d) is '
|
||||
'greater than 8'%self._num_replicas_per_host)
|
||||
if self._graph_dump_path:
|
||||
graph_io.write_graph(graph, self._graph_dump_path,
|
||||
if self._parameters.graph_dump_path:
|
||||
graph_io.write_graph(graph, self._parameters.graph_dump_path,
|
||||
'graph_before_tt.pbtxt')
|
||||
with graph.as_default():
|
||||
self._add_replica_id_to_graph()
|
||||
tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
|
||||
on_tpu=True)
|
||||
if self._graph_dump_path:
|
||||
graph_io.write_graph(graph, self._graph_dump_path,
|
||||
if self._parameters.graph_dump_path:
|
||||
graph_io.write_graph(graph, self._parameters.graph_dump_path,
|
||||
'graph_after_tt.pbtxt')
|
||||
return tensor_fetches
|
||||
|
||||
@ -1843,13 +1472,13 @@ class TensorTracer(object):
|
||||
self._num_replicas_per_host = 1
|
||||
self._num_hosts = 1
|
||||
self._replica_id = 0
|
||||
if self._graph_dump_path:
|
||||
graph_io.write_graph(graph, self._graph_dump_path,
|
||||
if self._parameters.graph_dump_path:
|
||||
graph_io.write_graph(graph, self._parameters.graph_dump_path,
|
||||
'graph_before_tt.pbtxt')
|
||||
with graph.as_default():
|
||||
tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
|
||||
on_tpu=False)
|
||||
if self._graph_dump_path:
|
||||
graph_io.write_graph(graph, self._graph_dump_path,
|
||||
if self._parameters.graph_dump_path:
|
||||
graph_io.write_graph(graph, self._parameters.graph_dump_path,
|
||||
'graph_after_tt.pbtxt')
|
||||
return tensor_fetches
|
||||
|
379
tensorflow/python/tpu/tensor_tracer_flags.py
Normal file
379
tensorflow/python/tpu/tensor_tracer_flags.py
Normal file
@ -0,0 +1,379 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ========================================================================
|
||||
"""Utilities to handle tensor tracer parameters."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import os
|
||||
import os.path
|
||||
import re
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
TRACE_MODE_NAN_INF = 'nan-inf'
TRACE_MODE_PART_TENSOR = 'part-tensor'
TRACE_MODE_FULL_TENSOR = 'full-tensor'
TRACE_MODE_FULL_IF_NAN = 'trace-back-if-nan'
TRACE_MODE_NORM = 'norm'
TRACE_MODE_MAX_ABS = 'max-abs'
_FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size'
_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
# Flags are passed as '--name=value' tokens; values may be single-quoted,
# double-quoted, unquoted, or absent (a bare '--name' presence flag).
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*')
_FLAG_NAME_ENABLE = 'enable'
_FLAG_NAME_TRACE_MODE = 'trace_mode'
_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace'
_FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
_FLAG_NAME_TRACE_BEFORE_OPS = 'trace_before_included_ops'
_FLAG_NAME_TRACE_AFTER_OPS = 'trace_after_included_ops'
_FLAG_NAME_SUBMODE = 'submode'
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
_FLAG_NAME_INCLUDED_CORES = 'included_cores'
_FLAG_NAME_TRACE_DIR = 'trace_dir'
_FLAG_NAME_REPORT_FILE = 'report_file'
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
_FLAG_NAME_OP_RANGE = 'op_range'
# Folder to dump the pre (before tensor tracer updates) and post graphs (after
# tensor tracer updates).
_FLAG_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'


class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  All parameters are parsed from the TENSOR_TRACER_FLAGS environment
  variable (or a caller-supplied environment mapping) at construction time
  and exposed as attributes.
  """

  def __init__(self, env=None):
    """Parses Tensor Tracer flags and populates the parameter attributes.

    Args:
      env: an optional mapping standing in for os.environ (used in tests).
        If falsy, os.environ is used.

    Raises:
      ValueError: If an unknown flag name, an invalid trace mode, or an
        invalid submode is provided via the flags environment variable.
    """
    if env:
      self._env = env
    else:
      self._env = os.environ
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_INCLUDED_OPTYPES)

    self.is_conditional_trace = self._is_conditional_trace_mode()
    self.trace_scalar_ops = self.is_flag_on(_FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.is_flag_on(_FLAG_NAME_USE_COMPACT_TRACE)

    # trace_ops_before_included and trace_ops_after_included denote the depth
    # of tracing relative to the ops given in --included_opnames or
    # --included_optypes
    # For example, in the below graph
    #                op1 --> op2 --> op3 --> op4 --> op5
    # If --included_opnames=op3 then only op3 will be traced.
    # If also --trace_before_included_ops=2 (trace_ops_before_included), then
    # op1 and op2 will be traced as they are at most 2 hops apart from an
    # included op. Similarly, if --trace_after_included_ops=2, then op4 and op5
    # will also be traced.
    self.trace_ops_before_included = self._get_flag_int_value(
        _FLAG_NAME_TRACE_BEFORE_OPS, 0)
    self.trace_ops_after_included = self._get_flag_int_value(
        _FLAG_NAME_TRACE_AFTER_OPS, 0)
    self.trace_stack_size = self._get_flag_int_value(
        _FLAG_NAME_TRACE_STACK_SIZE, 1)
    _, self.graph_dump_path = self.get_flag_value(
        _FLAG_DUMP_BEFORE_AFTER_GRAPHS)
    self.included_cores = self._flag_value_as_int_list(
        _FLAG_NAME_INCLUDED_CORES)
    # NOTE(review): this records whether the flag is *present*, not its
    # value — presumably a presence flag; confirm against callers.
    self.include_less_interesting_ops, _ = self.get_flag_value(
        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)

  def _is_conditional_trace_mode(self):
    """Returns True if tracing is gated on a runtime condition (NaN found)."""
    return self.trace_mode == TRACE_MODE_FULL_IF_NAN

  def _get_report_filepath(self):
    """Sets the path of the output report file.

    Returns:
      The report file path (possibly rebased under the test-undeclared-outputs
      directory), or None if the flag was not provided.

    Raises:
      ValueError: If report_file_path is an absolute path while
        use_test_undeclared_outputs_dir is set (the two are incompatible).
    """
    found, report_file_path = self.get_flag_value(
        _FLAG_NAME_REPORT_FILE)
    if found and report_file_path \
       and self.use_test_undeclared_outputs_dir():
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set,'
                         'report_file_path cannot be an absolute path (%s)'
                         %report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Sets the index range of the Ops that we will consider tracing.

    Returns:
      A (begin, end) integer pair; (-1, -1) means all ops are included.
    """
    found, op_range = self.get_flag_value(_FLAG_NAME_OP_RANGE)
    if not found or not op_range:
      # Missing or empty flag: include all ops.
      return (-1, -1)
    match = _OP_RANGE_PAT.match(op_range)
    if not match:
      # Malformed range string: fall back to including all ops.
      return (-1, -1)
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the directory for trace files, resolving the env-var redirect.

    Raises:
      ValueError: If both --trace_dir and --use_test_undeclared_outputs_dir
        are given (they are mutually exclusive).
    """
    found, trace_dir = self.get_flag_value(_FLAG_NAME_TRACE_DIR)
    if found and trace_dir \
       and self.use_test_undeclared_outputs_dir():
      raise ValueError('Cannot not use --%s and --%s at the same time'
                       %(_FLAG_NAME_TRACE_DIR,
                         _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if self.use_test_undeclared_outputs_dir():
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Checks if the given trace mode is valid.

    Returns:
      The trace mode string; defaults to TRACE_MODE_NORM when unset.

    Raises:
      ValueError: If the given trace mode is not one of the valid modes.
    """
    found, trace_mode = self.get_flag_value(_FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS, TRACE_MODE_FULL_IF_NAN
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.'
                       'Valid trace modes are: %s'%(trace_mode,
                                                    valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the brief (less verbose) submode is selected."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Checks if the given submode is valid.

    Returns:
      The submode string; defaults to _SUBMODE_DETAILED when unset.

    Raises:
      ValueError: If the given submode is not one of the valid submodes.
    """
    found, submode = self.get_flag_value(_FLAG_NAME_SUBMODE)
    if not found or not submode:
      # Default submode; the previous dead "if not submode: return" branch
      # after this assignment has been removed.
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.'
                       'Valid submodes are: %s'%(submode,
                                                 valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
      flags: a string that contains the flags.
      pos: where in flags to start the search.

    Returns:
      A pair where the first element is the regular-expression
      match found and the second element indicates if the match
      has a value.
    """
    # Try the most specific patterns first so quoted values are not
    # mis-parsed by the unquoted pattern.
    match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos)
    if match:
      return match, True
    match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos)
    if match:
      return match, True
    match = _FLAG_NO_QUOTE_PAT.match(flags, pos)
    if match:
      return match, True
    match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Validates if the TensorTrace flags passed are valid.

    Raises:
      ValueError: If an unrecognized flag name is found in the environment
        variable.
    """
    valid_flag_names = [
        _FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, _FLAG_NAME_USE_COMPACT_TRACE,
        _FLAG_NAME_TRACE_SCALAR_OPS, _FLAG_NAME_TRACE_BEFORE_OPS,
        _FLAG_NAME_TRACE_AFTER_OPS, _FLAG_NAME_TRACE_STACK_SIZE,
        _FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES,
        _FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES,
        _FLAG_NAME_INCLUDED_OPTYPES, _FLAG_NAME_TRACE_DIR,
        _FLAG_NAME_INCLUDED_CORES, _FLAG_NAME_REPORT_FILE,
        _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, _FLAG_NAME_OP_RANGE,
        _FLAG_DUMP_BEFORE_AFTER_GRAPHS
    ]
    tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in valid_flag_names:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names))
      pos = match.end()

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      the value of the flag as a list of ints; empty if the flag is unset
      or its value is not a comma-separated list of integers.
    """
    int_list = []
    found, flag_value = self.get_flag_value(wanted_flag_name)

    if found:
      try:
        integer_values = flag_value.split(',')
        int_list = [int(int_val) for int_val in integer_values]
      except ValueError:
        # Bug fix: log the offending flag value, not the (empty) result list.
        logging.warning('Cannot convert %s to int for flag %s', flag_value,
                        wanted_flag_name)
    return int_list

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      the value of the flag, or default_value if the flag is unset or not
      parsable as an int.
    """
    flag_int_value = default_value
    found, flag_value = self.get_flag_value(wanted_flag_name)

    if found:
      try:
        flag_int_value = int(flag_value)
      except ValueError:
        # Bug fix: log the unparsable flag value (not the default), lazily.
        logging.warning('Cannot convert %s to int for flag %s', flag_value,
                        wanted_flag_name)
    return flag_int_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flags.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag.
    """
    tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      flag_name = match.group(1)
      if has_value:
        flag_value = match.group(2)
      else:
        flag_value = None
      if flag_name == wanted_flag_name:
        return True, flag_value
      pos = match.end()
    # Removed an unreachable 'raise RuntimeError' here: the loop above always
    # returns or advances pos.

  def _flag_value_to_re_list(self, flag_name):
    """Converts the whitespace-separated flag value to a list of compiled REs.

    Args:
      flag_name: the name of the flag whose value holds the patterns.

    Returns:
      A (possibly empty) list of compiled regular expressions.
    """
    re_list = []
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return re_list
    list_of_values = flag_value.split()
    for v in list_of_values:
      r = re.compile(v)
      re_list.append(r)
    return re_list

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on (present, or set to a truthy value)."""
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      # Presence without a value counts as enabled.
      return True
    # Depends on the flag value.
    flag_value = flag_value.lower()
    enabled = flag_value in ['1', 't', 'true', 'y', 'yes']
    return enabled

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if self.is_flag_on(_FLAG_NAME_ENABLE):
      # Lazy %-style args so the env var is only formatted when logged.
      logging.info('Tensor Tracer is enabled with flags %s.',
                   self._env.get(_FLAGS_ENV_VAR))
      return True
    else:
      return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Returns:
      True if the output files should be written to the
      test-undeclared-outputs-directory defined via an
      env variable.
    """
    return self.is_flag_on(_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)
|
Loading…
Reference in New Issue
Block a user