diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py
index 70baea203cc..a1494e3660b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py
@@ -21,44 +21,56 @@ from __future__ import print_function
 import os
 import os.path
 import re
+import sys

 from tensorflow.contrib.tpu.python.ops import tpu_ops
 from tensorflow.contrib.tpu.python.tpu import tpu
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging

 _TRACER_LOG_PREFIX = ' [>>>TT>>>]'
 _DEVICE_TYPE_TPU = 'tpu'
 _DEVICE_TYPE_CPU = 'cpu'
-_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP'
 _TRACE_MODE_NAN_INF = 'nan-inf'
 _TRACE_MODE_PART_TENSOR = 'part-tensor'
 _TRACE_MODE_PART_TENSOR_SIZE = 3
 _TRACE_MODE_FULL_TENSOR = 'full-tensor'
-_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
-_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace'
-_RECORD_FILTERED_OUT = 'not-traced-filtered-out'
-_RECORD_SCALAR = 'not-traced-scalar'
-_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
-_RECORD_GET_TRACED = 'get-traced'
+_TRACE_MODE_NORM = 'norm'
+_TRACE_MODE_MAX_ABS = 'max-abs'
+_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
+_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
+_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
+_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
+_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
+_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
+_REASON_SCALAR_GET_TRACED = 'traced-scalar'
+_REASON_TENSOR_GET_TRACED = 'traced-tensor'
+_REASON_USER_INCLUDED = 'traced-user-included'
+_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
+_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
 _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
 _MARKER_SECTION_END = '!!!!!!! section-end:'
 _SECTION_NAME_CONFIG = 'configuration'
 _SECTION_NAME_REASON = 'reason'
 _SECTION_NAME_OP_LIST = 'op-list'
+_SECTION_NAME_TENSOR_LIST = 'tensor-list'
 _SECTION_NAME_GRAPH = 'graph'
 _FIELD_NAME_VERSION = 'version:'
 _FIELD_NAME_DEVICE = 'device:'
 _FIELD_NAME_TRACE_MODE = 'trace-mode:'
 _FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
 _FIELD_NAME_NUM_OPS = 'number-of-ops:'
+_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
 _FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
 _FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
 _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
@@ -66,13 +78,72 @@ _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
 _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
 _FLAG_NAME_ENABLE = 'enable'
 _FLAG_NAME_TRACE_MODE = 'trace_mode'
-_FLAG_NAME_INTERESTING_OPS = 'interesting_ops'
+_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
+_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
+_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
+_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
+_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
 _FLAG_NAME_TRACE_FILE = 'trace_file_path'
+_FLAG_NAME_REPORT_FILE = 'report_file_path'
 _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
 _FLAG_NAME_OP_RANGE = 'op_range'
 _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
 _OUTPUT_STREAM_ESCAPE = 'file://'
 _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
+_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
+_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
+
+
+def tensor_checkpoint(tensor, checkpoint_name):
+  """Adds a checkpoint with the given checkpoint name for the given tensor.
+
+  The tensor will be added to the list of tensors that will be traced by the
+  tensor tracer.
+
+  Args:
+    tensor: the tensor object for which the tracing is requested.
+    checkpoint_name: a string name for the checkpoint. This name has to be
+      unique if used within model comparison. Tensors that have the same
+      checkpoint identifier are compared in model comparison.
+
+  Returns:
+    The provided tensor.
+  """
+
+  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
+                                 (tensor, checkpoint_name))
+  return tensor
+
+
+def keras_layer_checkpoint(layer, checkpoint_name):
+  """An interface for adding the tensor outputs of a keras layer.
+
+  Encapsulates tensor_checkpoint.
+
+  Args:
+    layer: A keras layer.
+    checkpoint_name: a string name for the checkpoint. This name has to be
+      unique if used within model comparison. Tensors that have the same
+      checkpoint identifier are compared in model comparison.
+
+  Returns:
+    The provided layer.
+  """
+  try:
+    outputs = layer.output
+    if tensor_util.is_tensor(outputs):
+      tensor_checkpoint(outputs, checkpoint_name)
+    else:
+      idx = 0
+      for output_tensor in outputs:
+        # Check the element, not the enclosing list, so that every output
+        # of a multi-output layer is checkpointed.
+        if tensor_util.is_tensor(output_tensor):
+          tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx))
+        idx += 1
+  except (AttributeError, RuntimeError):
+    # layer.output raises if the layer has not been called yet or has
+    # multiple inbound nodes; in that case there is nothing to checkpoint.
+    pass
+  return layer
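
Taken together, tensor_checkpoint() and keras_layer_checkpoint() are the whole user-facing API of this change. A minimal usage sketch, with a hypothetical layer and names that are not part of this patch:

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tensor_tracer

    def build_hidden(features):
      dense = tf.keras.layers.Dense(128, name='hidden')
      activations = dense(features)
      # Tag one tensor; tensors that share a checkpoint name are lined up
      # against each other in model comparison.
      tensor_tracer.tensor_checkpoint(activations, 'hidden_activations')
      # Or tag every output of a layer at once; a multi-output layer gets
      # checkpoint names suffixed with _0, _1, ...
      tensor_tracer.keras_layer_checkpoint(dense, 'hidden')
      return activations
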
+ """ + try: + outputs = layer.output + if tensor_util.is_tensor(outputs): + tensor_checkpoint(outputs, '%s' % (checkpoint_name)) + else: + idx = 0 + for output_tensor in outputs: + if tensor_util.is_tensor(outputs): + tensor_checkpoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) + idx += 1 + except AttributeError: + pass + except RuntimeError: + pass + return layer class TensorTracer(object): @@ -105,6 +176,34 @@ class TensorTracer(object): match = _FLAG_NO_QUOTE_PAT.match(flags, pos) return match + @staticmethod + def validate_flag_names(): + """Validates if the TensorTrace flags passed are valid.""" + valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, + _FLAG_NAME_EXCLUDED_OPNAMES, + _FLAG_NAME_EXCLUDED_OPTYPES, + _FLAG_NAME_INCLUDED_OPNAMES, + _FLAG_NAME_INCLUDED_OPTYPES, + _FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE, + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, + _FLAG_NAME_OP_RANGE] + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + if flag_name not in valid_flag_names: + raise ValueError( + 'The flag name "%s" passed via the environment variable "%s" ' + 'is invalid. Valid flag names are:' + '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) + pos = match.end() + @staticmethod def print_flag_values(): """Prints all TensorTracer flags passed via environment variables.""" @@ -146,6 +245,20 @@ class TensorTracer(object): pos = match.end() return '' + @staticmethod + def flag_value_to_re_list(flag_name): + """Converts list of strings to compiled RE.""" + + re_list = [] + flag_value = TensorTracer.get_flag_value(flag_name) + if not flag_value: + return re_list + list_of_values = flag_value.split() + for v in list_of_values: + r = re.compile(v) + re_list.append(r) + return re_list + @staticmethod def is_enabled(): """Returns True if TensorTracer is enabled.""" @@ -186,29 +299,67 @@ class TensorTracer(object): """Checks if the given trace mode is valid.""" valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR] + _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, + _TRACE_MODE_MAX_ABS] if trace_mode not in valid_trace_modes: raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' 'Valid trace modes are: %s'%(trace_mode, valid_trace_modes)) @staticmethod - def should_trace(device_type, op): - """Returns True if the given Op should be traced.""" + def unsafe_op(op): + """Returns True if this op is not safe to be traced.""" - if device_type != _DEVICE_TYPE_TPU: - raise ValueError('Non TPU device type is not supported') if control_flow_util.IsInCond(op): + return True + # Reasons for not including following op types: + # Assign: cause incorrect result with CPU tracing. + # others: compilation problems. + if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']: + return True + return False + + @staticmethod + def device_mismatch(device_type, op): + if device_type == _DEVICE_TYPE_TPU: + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr + # pylint: enable=protected-access + return False + + @staticmethod + def unsafe_scalar_trace(op): + """Return true if scalar output tensor from Op is not safe to be traced.""" + + # Tracing the following causes cycle in the graph on TPU. 
@@ -186,29 +299,67 @@
     """Checks if the given trace mode is valid."""

     valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR,
-                         _TRACE_MODE_FULL_TENSOR]
+                         _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM,
+                         _TRACE_MODE_MAX_ABS]
     if trace_mode not in valid_trace_modes:
       raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.'
                        'Valid trace modes are: %s'%(trace_mode,
                                                     valid_trace_modes))

   @staticmethod
-  def should_trace(device_type, op):
-    """Returns True if the given Op should be traced."""
+  def unsafe_op(op):
+    """Returns True if this op is not safe to be traced."""

-    if device_type != _DEVICE_TYPE_TPU:
-      raise ValueError('Non TPU device type is not supported')
     if control_flow_util.IsInCond(op):
+      return True
+    # Reasons for not including the following op types:
+    #   Assign: causes incorrect results with CPU tracing.
+    #   others: cause compilation problems.
+    if op.type in ['Assign', 'Pack', 'Shape', 'Reshape', 'ArgMin', 'ArgMax']:
+      return True
+    return False
+
+  @staticmethod
+  def device_mismatch(device_type, op):
+    if device_type == _DEVICE_TYPE_TPU:
+      # pylint: disable=protected-access
+      return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr
+      # pylint: enable=protected-access
+    return False
+
+  @staticmethod
+  def unsafe_scalar_trace(op):
+    """Returns True if the given Op's scalar output is not safe to trace."""
+
+    # Tracing the following causes a cycle in the graph on TPU.
+    if op.type in ['LoopCond', 'Enter', 'Merge', 'Const',
+                   'Switch', 'Less', 'ReadVariableOp']:
+      return True
+    # Tracing the following causes casting issues with the norm trace mode,
+    # or other compilation issues on CPU.
+    if op.type in ['VarHandleOp', 'IteratorToStringHandle',
+                   'IteratorGetNext', 'OneShotIterator',
+                   'IteratorV2', 'MakeIterator',
+                   'BatchDatasetV2', 'MapDataset',
+                   'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
+                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']:
+      return True
+    return False
+
+  @staticmethod
+  def less_interesting_op(op):
+    """Returns True if the given Op is not interesting enough to trace."""
+
+    include_less_interesting = TensorTracer.get_flag_value(
+        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)
+    if include_less_interesting:
       return False
-    if op.type in ['Reshape', 'ArgMin', 'ArgMax']:
-      return False
-    # pylint: disable=protected-access
-    return tpu._TPU_REPLICATE_ATTR in op.node_def.attr
-    # pylint: enable=protected-access
+    return op.type in ['Const', 'Identity', 'Cast', 'Shape']

   @staticmethod
   def reason(op_idx, details):
-    """Returns why the Op at op_idx is traced or not."""
+    """Returns the reason why the Op at op_idx is traced or not."""
+
     return '%d %s'%(op_idx, details)

   @staticmethod
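
These predicates encode the default filtering policy: ops inside conditionals and a short blacklist of op types are never traced, graph-plumbing ops are not traced through their scalar outputs, and Const/Identity/Cast/Shape are dropped unless --include_less_interesting_ops is set. A toy illustration of the type-based checks (the graph and names are hypothetical):

    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tensor_tracer

    g = tf.Graph()
    with g.as_default():
      x = tf.placeholder(tf.float32, shape=[2, 2], name='x')
      y = tf.reshape(x, [4], name='y')

    TT = tensor_tracer.TensorTracer
    assert TT.unsafe_op(g.get_operation_by_name('y'))            # Reshape
    assert TT.unsafe_scalar_trace(g.get_operation_by_name('x'))  # Placeholder
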
@@ -274,6 +425,33 @@
     assert len(unsorted_ops) == len(sorted_ops)
     return (True, sorted_ops)

+  @staticmethod
+  def _make_op_and_tensor_maps(op_list):
+    """Creates various maps and lists from op_list.
+
+    Args:
+      op_list: a list of Ops.
+
+    Returns:
+      opname_idx_map: a map from an Op's name to its index in op_list.
+      tensor_list: a list of the output tensors of the Ops in op_list.
+      tensorname_idx_map: a map from an output tensor's name to its index
+        in tensor_list.
+    """
+
+    opname_idx_map = {}
+    tensor_list = []
+    tensorname_idx_map = {}
+    for op_id, op in enumerate(op_list):
+      if op.name in opname_idx_map:
+        raise ValueError('Duplicate Op name: %s'%op.name)
+      opname_idx_map[op.name] = op_id
+      for output_tensor in op.outputs:
+        if output_tensor.name not in tensorname_idx_map:
+          tensor_list.append(output_tensor)
+          tensorname_idx_map[output_tensor.name] = len(tensor_list) - 1
+    return (opname_idx_map, tensor_list, tensorname_idx_map)
+
   def __init__(self):
     """Initializes a TensorTracer.
@@ -281,16 +459,20 @@
     """
     self._version = 'use-outside-compilation'
     self._device_type = None
+    TensorTracer.validate_flag_names()
     self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE)
     if not self._trace_mode:
       self._trace_mode = _TRACE_MODE_NAN_INF
     TensorTracer.check_trace_mode(self._trace_mode)
     self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
     self._instrument_records = {}
-    interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS)
-    self._selected_ops = interesting_ops.split()
     self._set_trace_file_path()
+    self._set_report_file()
     self._set_op_range()
+    self._set_excluded_opnames()
+    self._set_excluded_optypes()
+    self._set_included_opnames()
+    self._set_included_optypes()
     self._num_replicas = None
     self._replica_id = None
@@ -318,10 +500,7 @@ class TensorTracer(object):
     """Sets the path of the output trace file."""

     self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE)
-    if not self._trace_file_path:
-      raise ValueError('--%s is not set in the environment variable %s'
-                       %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR))
-    elif TensorTracer.use_test_undeclared_outputs_dir():
+    if self._trace_file_path and TensorTracer.use_test_undeclared_outputs_dir():
       if os.path.isabs(self._trace_file_path):
         raise ValueError('If use_test_undeclared_outputs_dir is set,'
                          'trace_file_path cannot be an absolute path (%s)'
                          %self._trace_file_path)
       outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
       self._trace_file_path = os.path.join(outputs_dir,
                                            self._trace_file_path)
@@ -330,6 +509,22 @@

+  def _set_report_file(self):
+    """Sets the path of the output report file."""
+
+    self._report_file_path = TensorTracer.get_flag_value(_FLAG_NAME_REPORT_FILE)
+    if not self._report_file_path:
+      self._report_file = None
+      return
+    self._report_file = gfile.Open(self._report_file_path, 'w')
+
+  def _close_report_file(self):
+    if self._report_file:
+      self._report_file.close()
+
   def _set_op_range(self):
     """Sets the index range of the Ops that we will consider tracing."""
@@ -350,19 +545,48 @@
       return False
     return self._op_range[1] < 0 or idx <= self._op_range[1]

+  def _set_excluded_opnames(self):
+    self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list(
+        _FLAG_NAME_EXCLUDED_OPNAMES)
+
+  def _set_excluded_optypes(self):
+    self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list(
+        _FLAG_NAME_EXCLUDED_OPTYPES)
+
+  def _set_included_opnames(self):
+    self._included_opname_re_list = TensorTracer.flag_value_to_re_list(
+        _FLAG_NAME_INCLUDED_OPNAMES)
+
+  def _set_included_optypes(self):
+    self._included_optype_re_list = TensorTracer.flag_value_to_re_list(
+        _FLAG_NAME_INCLUDED_OPTYPES)
+
+  def _is_user_included_op(self, op):
+    for opname_re in self._included_opname_re_list:
+      if opname_re.match(op.name):
+        return True
+    for optype_re in self._included_optype_re_list:
+      if optype_re.match(op.type):
+        return True
+    return False
+
+  def _is_user_excluded_op(self, op):
+    for opname_re in self._excluded_opname_re_list:
+      if opname_re.match(op.name):
+        return True
+    for optype_re in self._excluded_optype_re_list:
+      if optype_re.match(op.type):
+        return True
+    return False
+
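The four setters compile the whitespace-separated flag values into regex lists, and the two matchers apply them to op names and types. Note the precedence established by _skip_op() further down: at the op level an include match wins over every other check, including an exclude match. A small sketch of the matching semantics (the patterns are illustrative):

    import re

    included_opnames = [re.compile(v) for v in 'dense.* logits'.split()]
    excluded_optypes = [re.compile(v) for v in 'Cast'.split()]
    assert any(r.match('dense_1/MatMul') for r in included_opnames)
    assert any(r.match('Cast') for r in excluded_optypes)
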
   def _write_report(self, content):
     """Writes the given content to the report."""

-    logging.info('%s %s'%(_TRACER_LOG_PREFIX, content))
-
-  def _is_selected_op(self, op_name):
-    """Returns True if the Op with op_name is selected to be traced."""
-
-    if not self._selected_ops:
-      return True
-    if op_name in self._selected_ops:
-      return True
-    return False
+    line = '%s %s'%(_TRACER_LOG_PREFIX, content)
+    if self._report_file:
+      self._report_file.write(line)
+    else:
+      logging.info(line)

   def _write_config_section(self):
     """Writes the config section of the report."""
@@ -382,15 +606,42 @@
       self._write_report('"%s" %s\n'%(key, self._instrument_records[key]))
     self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))

-  def _write_op_list_section(self, op_list):
+  def _write_op_list_section(self, op_list, tensorname_idx_map):
     """Writes the Op-list section of the report."""

     self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
     self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list)))
     for i in range(0, len(op_list)):
-      self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type))
+      op = op_list[i]
+      line = '%d "%s" %s'%(i, op.name, op.type)
+      for out_tensor in op.outputs:
+        if out_tensor.name not in tensorname_idx_map:
+          raise ValueError(
+              'out_tensor %s is not in tensorname_idx_map'%out_tensor.name)
+        line += ' %d'%tensorname_idx_map[out_tensor.name]
+      line += '\n'
+      self._write_report(line)
     self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))

+  def _write_tensor_list_section(self, tensor_list, opname_idx_map):
+    """Writes the tensor-list section of the report."""
+
+    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
+                                  _SECTION_NAME_TENSOR_LIST))
+    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list)))
+    for i in range(0, len(tensor_list)):
+      tensor = tensor_list[i]
+      line = '%d "%s"'%(i, tensor.name)
+      for consumer_op in tensor.consumers():
+        if consumer_op.name not in opname_idx_map:
+          raise ValueError(
+              'consumer_op %s is not in opname_idx_map'%consumer_op.name)
+        line += ' %d'%opname_idx_map[consumer_op.name]
+      line += '\n'
+      self._write_report(line)
+    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
+                                  _SECTION_NAME_TENSOR_LIST))
+
   def _write_graph_section(self, succeed, sorted_or_cycle):
     """Writes the graph section of the report."""
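
The op-list and tensor-list sections cross-reference each other: each op line ends with the indices of its output tensors, and each tensor line ends with the indices of its consumer ops. For a toy graph c = tf.matmul(a, b), the report would contain roughly the following (a hand-written sketch, not captured output):

    !!!!!!! section-begin: op-list
    number-of-ops: 3
    0 "a" Const 0
    1 "b" Const 1
    2 "c" MatMul 2
    !!!!!!! section-end: op-list
    !!!!!!! section-begin: tensor-list
    number-of-tensors: 3
    0 "a:0" 2
    1 "b:0" 2
    2 "c:0"
    !!!!!!! section-end: tensor-list
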
@@ -422,7 +673,7 @@ class TensorTracer(object):
     Args:
       op_name: the name of the Op that outputs the tensor to be printed.
       output_idx: which output of the Op it is (0 means the first output).
-      num_elements: number of elements to print.
+      num_elements: number of elements to print (-1 means print all).
       tensor: the tensor needs to be returned.
       output_tensor: the tensor needs to be printed.

@@ -430,10 +681,13 @@ class TensorTracer(object):
       The same tensor passed via the "tensor" argument.
     """
     msg = '"%s:%d" '%(op_name, output_idx)
-    output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path
+    if self._trace_file_path:
+      output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path
+    else:
+      output_stream = sys.stderr
     print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor),
                                     ' @', self._replica_id,
-                                    '\n', output_tensor,
+                                    '\n', output_tensor, '\n',
                                     summarize=num_elements,
                                     output_stream=output_stream)
     with ops.control_dependencies([print_op]):
@@ -442,7 +696,8 @@
     def _detect_nan_inf(tensor):
       """Trace function for detecting any NaN/Inf in the tensor."""

-      if tensor.dtype.is_floating:
+      if (tensor.dtype == dtypes.bfloat16 or
+          tensor.dtype == dtypes.float16):
         # Since host can't handle bf16, always convert tensor to f32.
         tensor = math_ops.cast(tensor, dtypes.float32)
         output_tensor = math_ops.reduce_any(
@@ -450,12 +705,19 @@ class TensorTracer(object):
                                     gen_math_ops.is_inf(tensor)))
       else:
         output_tensor = constant_op.constant(0)
-      return _print_tensor(op_name, output_idx, 1, tensor, output_tensor)
+      return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)

-    def _show_global_step(tensor):
-      """Trace function for printing the global step count."""
+    def _show_norm(tensor):
+      tensor = math_ops.cast(tensor, dtypes.float64)
+      output_tensor = linalg_ops.norm(tensor)
+      return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)

-      return _print_tensor(op_name, output_idx, 1, tensor, tensor)
+    def _show_max_abs(tensor):
+      output_tensor = math_ops.cast(math_ops.reduce_max(math_ops.abs(tensor)),
+                                    dtypes.float64)
+      zero = constant_op.constant(0, dtypes.float64)
+      # reduce_max of an empty tensor is -inf; clamp the result at zero.
+      output_tensor = gen_math_ops.maximum(zero, output_tensor)
+      return _print_tensor(op_name, output_idx, -1, tensor, output_tensor)

     def _show_part_tensor(tensor):
       """Trace function for printing part of the tensor."""
@@ -468,23 +730,139 @@ class TensorTracer(object):

       return _print_tensor(op_name, output_idx, -1, tensor, tensor)

-    if op_name == _GLOBAL_STEP_OP_NAME:
-      return _show_global_step
     if self._trace_mode == _TRACE_MODE_NAN_INF:
       return _detect_nan_inf
     if self._trace_mode == _TRACE_MODE_PART_TENSOR:
       return _show_part_tensor
     if self._trace_mode == _TRACE_MODE_FULL_TENSOR:
       return _show_full_tensor
+    if self._trace_mode == _TRACE_MODE_NORM:
+      return _show_norm
+    if self._trace_mode == _TRACE_MODE_MAX_ABS:
+      return _show_max_abs
     raise RuntimeError('Tensor trace fun for %s is not yet implemented'
                        %self._trace_mode)
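
In terms of plain TF ops, the two new modes reduce every traced tensor to a single scalar: 'norm' prints the Frobenius norm in float64, and 'max-abs' prints the largest absolute entry, clamped at zero so that an empty tensor (whose reduce_max is -inf) reports 0. A standalone sketch of the math, outside the tracer:

    import tensorflow as tf

    x = tf.constant([[3.0, -4.0]])
    norm = tf.norm(tf.cast(x, tf.float64))              # 5.0
    max_abs = tf.maximum(
        tf.constant(0, tf.float64),
        tf.cast(tf.reduce_max(tf.abs(x)), tf.float64))  # 4.0
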
+  def _skip_op(self, op_id, op, user_included, user_excluded):
+    """Returns True if we should not trace Op."""
+
+    if user_included:
+      self._instrument_records[op.name] = TensorTracer.reason(
+          op_id, _REASON_USER_INCLUDED)
+      return False
+    if user_excluded:
+      self._instrument_records[op.name] = TensorTracer.reason(
+          op_id, _REASON_USER_EXCLUDED)
+      return True
+    if not self._inside_op_range(op_id):
+      self._instrument_records[op.name] = TensorTracer.reason(
+          op_id, _REASON_OUTSIDE_OP_RANGE)
+      return True
+    if TensorTracer.unsafe_op(op):
+      self._instrument_records[op.name] = TensorTracer.reason(
+          op_id, _REASON_UNSAFE_OP)
+      return True
+    if TensorTracer.device_mismatch(self._device_type, op):
+      self._instrument_records[op.name] = TensorTracer.reason(
+          op_id, _REASON_DEVICE_MISMATCH)
+      return True
+    if TensorTracer.less_interesting_op(op):
+      self._instrument_records[op.name] = TensorTracer.reason(
+          op_id, _REASON_LESS_INTERESTING_OP)
+      return True
+    return False
+
+  def _skip_tensor(self, op_id, out_tensor, user_included,
+                   user_excluded):
+    """Returns True if we should not trace out_tensor."""
+
+    # Skips a tensor if the tensor has a non-numeric type.
+    # Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
+    # because it also excludes tensors with dtypes bool and
+    # float32_ref, which we actually want to trace.
+    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
+                                    dtypes.string])
+    if out_tensor.dtype in non_numeric_tensor_types:
+      self._instrument_records[out_tensor.name] = TensorTracer.reason(
+          op_id, _REASON_NON_NUMERIC_TENSOR)
+      return True
+
+    if user_included:
+      self._instrument_records[out_tensor.name] = TensorTracer.reason(
+          op_id, _REASON_USER_INCLUDED)
+      return False
+    if user_excluded:
+      self._instrument_records[out_tensor.name] = TensorTracer.reason(
+          op_id, _REASON_USER_EXCLUDED)
+      return True
+    if not out_tensor.get_shape().is_fully_defined():
+      self._instrument_records[out_tensor.name] = TensorTracer.reason(
+          op_id, _REASON_DYNAMIC_SHAPE)
+      return True
+    rank = len(out_tensor.shape)
+    if rank < 1:
+      # scalar
+      if TensorTracer.unsafe_scalar_trace(out_tensor.op):
+        self._instrument_records[out_tensor.name] = TensorTracer.reason(
+            op_id, _REASON_UNSAFE_SCALAR)
+        return True
+      else:
+        self._instrument_records[out_tensor.name] = TensorTracer.reason(
+            op_id, _REASON_SCALAR_GET_TRACED)
+        return False
+    else:
+      # tensor
+      self._instrument_records[out_tensor.name] = TensorTracer.reason(
+          op_id, _REASON_TENSOR_GET_TRACED)
+      return False
+
+  def _pre_tracing(self, graph):
+    """Work that needs to be done before TPU or CPU tracing."""
+
+    operations = graph.get_operations()
+    (opname_idx_map, tensor_list, tensorname_idx_map) = (
+        TensorTracer._make_op_and_tensor_maps(operations))
+    self._write_config_section()
+    self._write_op_list_section(operations, tensorname_idx_map)
+    self._write_tensor_list_section(tensor_list, opname_idx_map)
+    # Do the topological sort before adding any nodes to the graph.
+    (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
+    return (operations, succeed, sorted_or_cycle)
+
+  def _post_tracing(self, succeed, sorted_or_cycle):
+    """Work that needs to be done after TPU or CPU tracing."""
+
+    self._write_reason_section()
+    self._write_graph_section(succeed, sorted_or_cycle)
+    self._close_report_file()
+
+  def _get_checkpoints(self, graph):
+    """Returns the Ops that produce the tensors traced with the API.
+
+    Args:
+      graph: the graph of Ops.
+
+    Returns:
+      A set of operation names which should be traced.
+    """
+
+    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
+                                  _TENSOR_TRACER_CHECKPOINT))
+    checkpoint_operations = set()
+    tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION)
+    for (tensor, checkpoint_name) in tensor_tracer_variables:
+      self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
+      checkpoint_operations.add(tensor.op.name)
+    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
+                                  _TENSOR_TRACER_CHECKPOINT))
+    return checkpoint_operations
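
_get_checkpoints() ties the collection-based API back into tracing: when any tensor_checkpoint() calls were made on the graph, trace_tpu() and trace_cpu() below restrict tracing to exactly the producing ops; with an empty collection, every op that survives _skip_op()/_skip_tensor() is a candidate. Continuing the hypothetical graph from the earlier sketch:

    with g.as_default():
      z = tf.matmul(x, x, name='z')
      tensor_tracer.tensor_checkpoint(z, 'matmul_out')
    # _get_checkpoints(g) would now return set(['z']), so only the MatMul
    # op remains a tracing candidate.
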
+
   def trace_tpu(self, graph, result_tensor, num_replicas=None):
     """Traces the tensors generated by TPU Ops in a TF graph.

     Args:
-      graph: the graph of Ops.
+      graph: the graph of Ops executed on the TPU.
       result_tensor: a result tensor of evaluating the graph.
       num_replicas: number of replicas used on the TPU.

@@ -502,38 +880,22 @@
     TensorTracer.check_device_type(self._device_type)
     result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
                                                        result_tensor)
-    self._write_config_section()
+    (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
     tracing_ops = []
-    operations = graph.get_operations()
-    self._write_op_list_section(operations)
-    # Does the topological sort before adding any nodes to the graph.
-    (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
+    checkpoint_operations = self._get_checkpoints(graph)
+
     for op_id, op in enumerate(operations):
-      if not self._inside_op_range(op_id):
-        self._instrument_records[op.name] = TensorTracer.reason(
-            op_id, _RECORD_OUTSIDE_OP_RANGE)
+      if checkpoint_operations and op.name not in checkpoint_operations:
         continue
-      if not TensorTracer.should_trace(self._device_type, op):
-        self._instrument_records[op.name] = TensorTracer.reason(
-            op_id, _RECORD_SHOULD_NOT_TRACE)
-        continue
-      if not self._is_selected_op(op.name):
-        self._instrument_records[op.name] = TensorTracer.reason(
-            op_id, _RECORD_FILTERED_OUT)
+      user_included = self._is_user_included_op(op)
+      user_excluded = self._is_user_excluded_op(op)
+      if self._skip_op(op_id, op, user_included, user_excluded):
         continue
       for i in range(len(op.outputs)):
         out_tensor = op.outputs[i]
-        if not out_tensor.get_shape().is_fully_defined():
-          self._instrument_records[out_tensor.name] = TensorTracer.reason(
-              op_id, _RECORD_DYNAMIC_SHAPE)
-          continue  # cannot trace tensors with dynamic shape.
-        rank = len(out_tensor.shape)
-        if rank < 1:
-          self._instrument_records[out_tensor.name] = TensorTracer.reason(
-              op_id, _RECORD_SCALAR)
-          continue  # cannot trace scalar.
-        self._instrument_records[out_tensor.name] = TensorTracer.reason(
-            op_id, _RECORD_GET_TRACED)
+        if self._skip_tensor(op_id, out_tensor, user_included,
+                             user_excluded):
+          continue
         consumers = out_tensor.consumers()
         trace_op = tpu.outside_compilation(
             self._make_tensor_trace_fun(op.name, i), out_tensor)
@@ -546,8 +908,45 @@
         # if there is no consumer, we will add the control dependence later
         # when we add the control dependency to the output operations.
         tracing_ops.append(trace_op)
-
-    self._write_reason_section()
-    self._write_graph_section(succeed, sorted_or_cycle)
-
+    self._post_tracing(succeed, sorted_or_cycle)
     return (result_tensor_copy, tracing_ops)
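
trace_tpu() is meant to run on a fully built TPU graph, after model construction but before the session starts; a rough usage sketch (the loss tensor, replica count, and surrounding training code are hypothetical):

    tt = tensor_tracer.TensorTracer()
    loss_copy, tracing_ops = tt.trace_tpu(
        tf.get_default_graph(), result_tensor=loss, num_replicas=8)
    # Downstream code should use loss_copy, and add control dependencies on
    # tracing_ops for trace ops without consumers, so the
    # outside-compilation prints are not pruned from the graph.
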
+ """ + + self._device_type = _DEVICE_TYPE_CPU + TensorTracer.check_device_type(self._device_type) + self._num_replicas = 1 + self._replica_id = 0 + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph) + tracing_calls = {} + checkpoint_operations = self._get_checkpoints(graph) + + for op_id, op in enumerate(operations): + if checkpoint_operations and op.name not in checkpoint_operations: + continue + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + if self._skip_op(op_id, op, user_included, user_excluded): + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue + trace_fun = self._make_tensor_trace_fun(op.name, i) + trace_call = (trace_fun, [out_tensor]) + trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) + tracing_calls[trace_call_key] = trace_call + self._post_tracing(succeed, sorted_or_cycle) + return tracing_calls diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 84816d70d08..fe2ac61bf91 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -336,6 +336,16 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + if tensor_tracer.TensorTracer.is_enabled(): + tt = tensor_tracer.TensorTracer() + tracing_calls = tt.trace_cpu(ops.get_default_graph()) + tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls) + tracing_functions = tracing_call_ret.values() + if tracing_functions: + if hooks: + hooks.extend([_OutfeedHostCallHook(tracing_functions)]) + else: + hooks = [_OutfeedHostCallHook(tracing_functions)] hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec(