diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index c0de9953adf..20d0821c78f 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -50,6 +50,8 @@ _TRACE_MODE_NAN_INF = 'nan-inf' _TRACE_MODE_PART_TENSOR = 'part-tensor' _TRACE_MODE_PART_TENSOR_SIZE = 3 _TRACE_MODE_FULL_TENSOR = 'full-tensor' +_TRACE_MODE_FULL_IF_NAN = 'trace-back-if-nan' +_FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size' _TRACE_MODE_NORM = 'norm' _TRACE_MODE_MAX_ABS = 'max-abs' _SUBMODE_BRIEF = 'brief' @@ -277,22 +279,17 @@ class TensorTracer(object): @staticmethod def validate_flag_names(): """Validates if the TensorTrace flags passed are valid.""" - valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, - _FLAG_NAME_USE_COMPACT_TRACE, - _FLAG_NAME_TRACE_SCALAR_OPS, - _FLAG_NAME_TRACE_BEFORE_OPS, - _FLAG_NAME_TRACE_AFTER_OPS, - _FLAG_NAME_SUBMODE, - _FLAG_NAME_EXCLUDED_OPNAMES, - _FLAG_NAME_EXCLUDED_OPTYPES, - _FLAG_NAME_INCLUDED_OPNAMES, - _FLAG_NAME_INCLUDED_OPTYPES, - _FLAG_NAME_TRACE_DIR, - _FLAG_NAME_REPORT_FILE, - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, - _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, - _FLAG_NAME_OP_RANGE, - _FLAG_DUMP_BEFORE_AFTER_GRAPHS] + valid_flag_names = [ + _FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, _FLAG_NAME_USE_COMPACT_TRACE, + _FLAG_NAME_TRACE_SCALAR_OPS, _FLAG_NAME_TRACE_BEFORE_OPS, + _FLAG_NAME_TRACE_AFTER_OPS, _FLAG_NAME_TRACE_STACK_SIZE, + _FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES, + _FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES, + _FLAG_NAME_INCLUDED_OPTYPES, _FLAG_NAME_TRACE_DIR, + _FLAG_NAME_REPORT_FILE, _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, _FLAG_NAME_OP_RANGE, + _FLAG_DUMP_BEFORE_AFTER_GRAPHS + ] tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) if not tensor_tracer_flags: return @@ -336,7 +333,7 @@ class TensorTracer(object): return result @staticmethod - def 
flag_value_as_int(wanted_flag_name, default_value): + def get_flag_int_value(wanted_flag_name, default_value): """Returns the int value of a TensorTracer flag. Args: @@ -458,9 +455,10 @@ class TensorTracer(object): def check_trace_mode(trace_mode): """Checks if the given trace mode is valid.""" - valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, - _TRACE_MODE_MAX_ABS] + valid_trace_modes = [ + _TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, _TRACE_MODE_FULL_TENSOR, + _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS, _TRACE_MODE_FULL_IF_NAN + ] if trace_mode not in valid_trace_modes: raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' 'Valid trace modes are: %s'%(trace_mode, @@ -650,6 +648,10 @@ class TensorTracer(object): tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 return (opname_idx_map, tensor_list, tensorname_idx_map) + @staticmethod + def is_conditional_trace_mode(trace_mode): + return trace_mode == _TRACE_MODE_FULL_IF_NAN + def __init__(self): """Initializes a TensorTracer. @@ -680,6 +682,8 @@ class TensorTracer(object): self._num_hosts = None self._replica_id = None self._included_op_full_names = set() + self._is_conditional_trace = TensorTracer.is_conditional_trace_mode( + self._trace_mode) self._trace_scalar_ops = TensorTracer._is_flag_on( _FLAG_NAME_TRACE_SCALAR_OPS) @@ -693,10 +697,12 @@ class TensorTracer(object): # op1 and op2 will be traced as they are at most 2 hops apart from an # included op. Similarly, if --trace_after_included_ops=2, then op4 and op5 # will also be traced. 
- self._trace_ops_before_included = TensorTracer.flag_value_as_int( + self._trace_ops_before_included = TensorTracer.get_flag_int_value( _FLAG_NAME_TRACE_BEFORE_OPS, 0) - self._trace_ops_after_included = TensorTracer.flag_value_as_int( + self._trace_ops_after_included = TensorTracer.get_flag_int_value( _FLAG_NAME_TRACE_AFTER_OPS, 0) + self._trace_stack_size = TensorTracer.get_flag_int_value( + _FLAG_NAME_TRACE_STACK_SIZE, 1) _, self._graph_dump_path = TensorTracer.get_flag_value( _FLAG_DUMP_BEFORE_AFTER_GRAPHS) @@ -991,6 +997,25 @@ class TensorTracer(object): output_tensor = array_ops.reshape(output_tensor, [1]) return output_tensor + def _detect_inf_nan_producer(tensor): + """Checks if the tensor is the first NaN/Inf tensor in the computation path.""" + if tensor.op.inputs: + inp_check = [ + _detect_nan_inf(inp_tensor) for inp_tensor in tensor.op.inputs + ] + is_any_input_inf_nan = math_ops.add_n(inp_check) + else: + is_any_input_inf_nan = constant_op.constant(0, dtypes.bool) + is_current_tensor_inf_nan = _detect_nan_inf(tensor) + # An op is a NaN/INF producer only when all inputs are nan/inf free ( + # is_any_input_inf_nan = 0), and its output has nan/inf ( + # is_current_tensor_inf_nan=1). Below will be 1 if the op is a nan/inf producer. 
+ is_nan_producer = is_current_tensor_inf_nan - is_any_input_inf_nan + is_nan_producer = math_ops.reduce_any(is_nan_producer > 0) + return is_nan_producer + + if self._trace_mode == _TRACE_MODE_FULL_IF_NAN: + return _detect_inf_nan_producer(tensor) if self._trace_mode == _TRACE_MODE_NAN_INF: return _detect_nan_inf(tensor) if self._trace_mode == _TRACE_MODE_PART_TENSOR: @@ -1064,6 +1089,36 @@ class TensorTracer(object): return _print_tensor(tensor_name, -1, tensor, tensor) + def _show_full_tensors(tensor): + """Prints the full tensor values for the tensors that are _trace_stack_size hops away from a given tensor.""" + + def _get_distance_k_tensors(k_before=0): + """Returns the tensors that are at most k_before hops away from the tensor.""" + if k_before < 0: + return [] + visited_tensors = {tensor: 0} + visitor_queue = [tensor] + head = 0 + while head < len(visitor_queue): + current_tensor = visitor_queue[head] + head += 1 + distance = visited_tensors[current_tensor] + if distance == k_before: + break + for input_tensor in current_tensor.op.inputs: + if input_tensor in visited_tensors: + continue + visitor_queue.append(input_tensor) + visited_tensors[input_tensor] = distance + 1 + return visitor_queue + + tensors_to_print = _get_distance_k_tensors(self._trace_stack_size) + print_ops = [_print_tensor(t.name, -1, t, t) for t in tensors_to_print] + with ops.control_dependencies(print_ops): + return constant_op.constant(0) + + if self._trace_mode == _TRACE_MODE_FULL_IF_NAN: + return _show_full_tensors if self._trace_mode == _TRACE_MODE_PART_TENSOR: return _show_part_tensor # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF, @@ -1589,12 +1644,28 @@ class TensorTracer(object): trace_op = self._save_tensor_value_to_cache_op(graph, cache_idx, processed_out_tensor) - elif on_tpu: - trace_op = tpu.outside_compilation( - self._make_tensor_trace_fun(tensor_name), processed_out_tensor) else: - trace_fun = self._make_tensor_trace_fun(tensor_name) - trace_op = 
trace_fun(processed_out_tensor) + + def tpu_wrap_trace_fn(tensor, out_tensor_name): + """Wraps the trace_fn with outside compilation if on TPUs.""" + tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name) + if on_tpu: + return tpu.outside_compilation(tensor_trace_fn, tensor) + else: + return tensor_trace_fn(tensor) + + def conditional_trace_fn(predicate_tensor, out_tensor, trace_fn, + out_tensor_name): + """Creates a cond op that traces the out_tensor if predicate is satisfied.""" + return control_flow_ops.cond( + predicate_tensor, lambda: trace_fn(out_tensor, out_tensor_name), + lambda: constant_op.constant(0)).op + + if self._is_conditional_trace: + trace_op = conditional_trace_fn(processed_out_tensor, out_tensor, + tpu_wrap_trace_fn, tensor_name) + else: + trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name) if is_a_fetched_tensor: tracing_ops.append(trace_op) @@ -1727,5 +1798,3 @@ class TensorTracer(object): graph_io.write_graph(graph, self._graph_dump_path, 'graph_after_tt.pbtxt') return tensor_fetches - -