[Tensor tracer]: add a new mode: trace-back-if-nan.
PiperOrigin-RevId: 241820394
This commit is contained in:
parent
6009b2288a
commit
368474cc4d
@ -50,6 +50,8 @@ _TRACE_MODE_NAN_INF = 'nan-inf'
|
||||
_TRACE_MODE_PART_TENSOR = 'part-tensor'
|
||||
_TRACE_MODE_PART_TENSOR_SIZE = 3
|
||||
_TRACE_MODE_FULL_TENSOR = 'full-tensor'
|
||||
_TRACE_MODE_FULL_IF_NAN = 'trace-back-if-nan'
|
||||
_FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size'
|
||||
_TRACE_MODE_NORM = 'norm'
|
||||
_TRACE_MODE_MAX_ABS = 'max-abs'
|
||||
_SUBMODE_BRIEF = 'brief'
|
||||
@ -277,22 +279,17 @@ class TensorTracer(object):
|
||||
@staticmethod
|
||||
def validate_flag_names():
|
||||
"""Validates if the TensorTrace flags passed are valid."""
|
||||
valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE,
|
||||
_FLAG_NAME_USE_COMPACT_TRACE,
|
||||
_FLAG_NAME_TRACE_SCALAR_OPS,
|
||||
_FLAG_NAME_TRACE_BEFORE_OPS,
|
||||
_FLAG_NAME_TRACE_AFTER_OPS,
|
||||
_FLAG_NAME_SUBMODE,
|
||||
_FLAG_NAME_EXCLUDED_OPNAMES,
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES,
|
||||
_FLAG_NAME_INCLUDED_OPNAMES,
|
||||
_FLAG_NAME_INCLUDED_OPTYPES,
|
||||
_FLAG_NAME_TRACE_DIR,
|
||||
_FLAG_NAME_REPORT_FILE,
|
||||
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS,
|
||||
_FLAG_NAME_OP_RANGE,
|
||||
_FLAG_DUMP_BEFORE_AFTER_GRAPHS]
|
||||
valid_flag_names = [
|
||||
_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, _FLAG_NAME_USE_COMPACT_TRACE,
|
||||
_FLAG_NAME_TRACE_SCALAR_OPS, _FLAG_NAME_TRACE_BEFORE_OPS,
|
||||
_FLAG_NAME_TRACE_AFTER_OPS, _FLAG_NAME_TRACE_STACK_SIZE,
|
||||
_FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES,
|
||||
_FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES,
|
||||
_FLAG_NAME_INCLUDED_OPTYPES, _FLAG_NAME_TRACE_DIR,
|
||||
_FLAG_NAME_REPORT_FILE, _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
|
||||
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, _FLAG_NAME_OP_RANGE,
|
||||
_FLAG_DUMP_BEFORE_AFTER_GRAPHS
|
||||
]
|
||||
tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR)
|
||||
if not tensor_tracer_flags:
|
||||
return
|
||||
@ -336,7 +333,7 @@ class TensorTracer(object):
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def flag_value_as_int(wanted_flag_name, default_value):
|
||||
def get_flag_int_value(wanted_flag_name, default_value):
|
||||
"""Returns the int value of a TensorTracer flag.
|
||||
|
||||
Args:
|
||||
@ -458,9 +455,10 @@ class TensorTracer(object):
|
||||
def check_trace_mode(trace_mode):
|
||||
"""Checks if the given trace mode is valid."""
|
||||
|
||||
valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR,
|
||||
_TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM,
|
||||
_TRACE_MODE_MAX_ABS]
|
||||
valid_trace_modes = [
|
||||
_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, _TRACE_MODE_FULL_TENSOR,
|
||||
_TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS, _TRACE_MODE_FULL_IF_NAN
|
||||
]
|
||||
if trace_mode not in valid_trace_modes:
|
||||
raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.'
|
||||
'Valid trace modes are: %s'%(trace_mode,
|
||||
@ -650,6 +648,10 @@ class TensorTracer(object):
|
||||
tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
|
||||
return (opname_idx_map, tensor_list, tensorname_idx_map)
|
||||
|
||||
@staticmethod
|
||||
def is_conditional_trace_mode(trace_mode):
|
||||
return trace_mode == _TRACE_MODE_FULL_IF_NAN
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a TensorTracer.
|
||||
|
||||
@ -680,6 +682,8 @@ class TensorTracer(object):
|
||||
self._num_hosts = None
|
||||
self._replica_id = None
|
||||
self._included_op_full_names = set()
|
||||
self._is_conditional_trace = TensorTracer.is_conditional_trace_mode(
|
||||
self._trace_mode)
|
||||
self._trace_scalar_ops = TensorTracer._is_flag_on(
|
||||
_FLAG_NAME_TRACE_SCALAR_OPS)
|
||||
|
||||
@ -693,10 +697,12 @@ class TensorTracer(object):
|
||||
# op1 and op2 will be traced as they are at most 2 hops apart from an
|
||||
# included op. Similarly, if --trace_after_included_ops=2, then op4 and op5
|
||||
# will also be traced.
|
||||
self._trace_ops_before_included = TensorTracer.flag_value_as_int(
|
||||
self._trace_ops_before_included = TensorTracer.get_flag_int_value(
|
||||
_FLAG_NAME_TRACE_BEFORE_OPS, 0)
|
||||
self._trace_ops_after_included = TensorTracer.flag_value_as_int(
|
||||
self._trace_ops_after_included = TensorTracer.get_flag_int_value(
|
||||
_FLAG_NAME_TRACE_AFTER_OPS, 0)
|
||||
self._trace_stack_size = TensorTracer.get_flag_int_value(
|
||||
_FLAG_NAME_TRACE_STACK_SIZE, 1)
|
||||
|
||||
_, self._graph_dump_path = TensorTracer.get_flag_value(
|
||||
_FLAG_DUMP_BEFORE_AFTER_GRAPHS)
|
||||
@ -991,6 +997,25 @@ class TensorTracer(object):
|
||||
output_tensor = array_ops.reshape(output_tensor, [1])
|
||||
return output_tensor
|
||||
|
||||
def _detect_inf_nan_producer(tensor):
|
||||
"""Checks if the tensor is the first NaN/Inf tensor in the computation path."""
|
||||
if tensor.op.inputs:
|
||||
inp_check = [
|
||||
_detect_nan_inf(inp_tensor) for inp_tensor in tensor.op.inputs
|
||||
]
|
||||
is_any_input_inf_nan = math_ops.add_n(inp_check)
|
||||
else:
|
||||
is_any_input_inf_nan = constant_op.constant(0, dtypes.bool)
|
||||
is_current_tensor_inf_nan = _detect_nan_inf(tensor)
|
||||
# An op is NaN/INF producer only when all inputs are nan/inf free (
|
||||
# is_any_input_inf_nan = 0), and its output has nan/inf (
|
||||
# is_current_tensor_inf_nan=1). Below will be 1 if op nan/inf is producer.
|
||||
is_nan_producer = is_current_tensor_inf_nan - is_any_input_inf_nan
|
||||
is_nan_producer = math_ops.reduce_any(is_nan_producer > 0)
|
||||
return is_nan_producer
|
||||
|
||||
if self._trace_mode == _TRACE_MODE_FULL_IF_NAN:
|
||||
return _detect_inf_nan_producer(tensor)
|
||||
if self._trace_mode == _TRACE_MODE_NAN_INF:
|
||||
return _detect_nan_inf(tensor)
|
||||
if self._trace_mode == _TRACE_MODE_PART_TENSOR:
|
||||
@ -1064,6 +1089,36 @@ class TensorTracer(object):
|
||||
|
||||
return _print_tensor(tensor_name, -1, tensor, tensor)
|
||||
|
||||
def _show_full_tensors(tensor):
|
||||
"""Prints the full tensor values for the tensors that are _trace_stack_size hops away from a given tensor."""
|
||||
|
||||
def _get_distance_k_tensors(k_before=0):
|
||||
"""Returns the tensors that are at most k_before hops away from the tensor."""
|
||||
if k_before < 0:
|
||||
return []
|
||||
visited_tensors = {tensor: 0}
|
||||
visitor_queue = [tensor]
|
||||
head = 0
|
||||
while head < len(visitor_queue):
|
||||
current_tensor = visitor_queue[head]
|
||||
head += 1
|
||||
distance = visited_tensors[current_tensor]
|
||||
if distance == k_before:
|
||||
break
|
||||
for input_tensor in current_tensor.op.inputs:
|
||||
if input_tensor in visited_tensors:
|
||||
continue
|
||||
visitor_queue.append(input_tensor)
|
||||
visited_tensors[input_tensor] = distance + 1
|
||||
return visitor_queue
|
||||
|
||||
tensors_to_print = _get_distance_k_tensors(self._trace_stack_size)
|
||||
print_ops = [_print_tensor(t.name, -1, t, t) for t in tensors_to_print]
|
||||
with ops.control_dependencies(print_ops):
|
||||
return constant_op.constant(0)
|
||||
|
||||
if self._trace_mode == _TRACE_MODE_FULL_IF_NAN:
|
||||
return _show_full_tensors
|
||||
if self._trace_mode == _TRACE_MODE_PART_TENSOR:
|
||||
return _show_part_tensor
|
||||
# The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF,
|
||||
@ -1589,12 +1644,28 @@ class TensorTracer(object):
|
||||
trace_op = self._save_tensor_value_to_cache_op(graph,
|
||||
cache_idx,
|
||||
processed_out_tensor)
|
||||
elif on_tpu:
|
||||
trace_op = tpu.outside_compilation(
|
||||
self._make_tensor_trace_fun(tensor_name), processed_out_tensor)
|
||||
else:
|
||||
trace_fun = self._make_tensor_trace_fun(tensor_name)
|
||||
trace_op = trace_fun(processed_out_tensor)
|
||||
|
||||
def tpu_wrap_trace_fn(tensor, out_tensor_name):
|
||||
"""Wraps the trace_fn with outside compilation if on TPUs."""
|
||||
tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name)
|
||||
if on_tpu:
|
||||
return tpu.outside_compilation(tensor_trace_fn, tensor)
|
||||
else:
|
||||
return tensor_trace_fn(tensor)
|
||||
|
||||
def conditional_trace_fn(predicate_tensor, out_tensor, trace_fn,
|
||||
out_tensor_name):
|
||||
"""Creates a cond op that traces the out_tensor if predicate is satisfied."""
|
||||
return control_flow_ops.cond(
|
||||
predicate_tensor, lambda: trace_fn(out_tensor, out_tensor_name),
|
||||
lambda: constant_op.constant(0)).op
|
||||
|
||||
if self._is_conditional_trace:
|
||||
trace_op = conditional_trace_fn(processed_out_tensor, out_tensor,
|
||||
tpu_wrap_trace_fn, tensor_name)
|
||||
else:
|
||||
trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)
|
||||
|
||||
if is_a_fetched_tensor:
|
||||
tracing_ops.append(trace_op)
|
||||
@ -1727,5 +1798,3 @@ class TensorTracer(object):
|
||||
graph_io.write_graph(graph, self._graph_dump_path,
|
||||
'graph_after_tt.pbtxt')
|
||||
return tensor_fetches
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user