diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py index 159fcaa2bf3..9d62c1b8a97 100644 --- a/tensorflow/lite/python/op_hint.py +++ b/tensorflow/lite/python/op_hint.py @@ -435,6 +435,7 @@ class OpHint(object): Args: *args: List of inputs to be converted (should be Tf.Tensor). **kwargs: This allows 'names' which should be a list of names. + Returns: Wrapped inputs (identity standins that have additional metadata). These are also are also tf.Tensor's. @@ -453,6 +454,7 @@ class OpHint(object): Args: *args: List of outputs to be converted (should be tf.Tensor). **kwargs: See + Returns: Wrapped outputs (identity standins that have additional metadata). These are also tf.Tensor's. @@ -574,8 +576,8 @@ class _LiteAggregateOperand(_LiteOperand): elif self.aggregation == OpHint.AGGREGATE_STACK: pass else: - raise ValueError( - "Invalid aggregation type %r specified" % self.aggregation) + raise ValueError("Invalid aggregation type %r specified" % + self.aggregation) return self.flattened def flatten(self): @@ -646,8 +648,8 @@ class _LiteAggregateOperand(_LiteOperand): stack_node.attr["num"].i = len(flattened) output_type = flattened[0].attr["T"].type stack_node.attr["T"].type = output_type - stack_node.input.append(_tensorflow_output_name( - fused_op_name, output_index)) + stack_node.input.append( + _tensorflow_output_name(fused_op_name, output_index)) out_graphdef.node.extend([stack_node]) for idx, discrete in enumerate(flattened): @@ -675,11 +677,10 @@ class _LiteFuncCall(object): inputs: inputs to the op (hash from index # to argument) outputs: outputs to the op (hash from index # to argument) function_name: the tflite custom op name to use - uuid: a unique call id for this particular call (i.e. - multiple function calls would have the same function_name but different - uuids. - params: A param name to key value for op constant data. I.e. for - axis on a reduction, strides on a convolution, etc. + uuid: a unique call id for this particular call (i.e. multiple function + calls would have the same function_name but different uuids. + params: A param name to key value for op constant data. I.e. for axis on a + reduction, strides on a convolution, etc. level: Level of the OpHint. children_inputs_mappings: If the Ophint has children, children inputs mappings indicate how their inputs & outputs are mapped. @@ -700,6 +701,7 @@ class _LiteFuncCall(object): Returns: Tuple of (inputs, outputs). where input and output i a list of names. """ + def _flatten(input_or_output_dict): flattened_items = [] for item in input_or_output_dict.values(): @@ -709,6 +711,7 @@ class _LiteFuncCall(object): return _flatten(self.inputs), _flatten(self.outputs) def __str__(self): + def format_args(items): s = "" for idx, item in items.iteritems(): @@ -739,8 +742,8 @@ def _find_all_hints_in_nodes(nodes): for node in nodes: attr = node.attr # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip - if (OpHint.FUNCTION_UUID_ATTR not in attr - or not attr[OpHint.FUNCTION_UUID_ATTR].s): + if (OpHint.FUNCTION_UUID_ATTR not in attr or + not attr[OpHint.FUNCTION_UUID_ATTR].s): continue uuid = attr[OpHint.FUNCTION_UUID_ATTR].s @@ -751,9 +754,11 @@ def _find_all_hints_in_nodes(nodes): call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i # Get sorting and aggregation information - sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i - if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None) - if sort == -1: sort = None + sort = ( + attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i + if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None) + if sort == -1: + sort = None aggregation = None if OpHint.FUNCTION_AGGREGATE_ATTR in attr: aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s) @@ -887,6 +892,7 @@ def _tensor_name_base(full_tensor_name): Args: full_tensor_name: A tensor name that is annotated with a device placement (this is what tensor flow introspection gives). + Returns: A name without any device assignment. """ @@ -919,10 +925,10 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set, while next_to_visit: current_node = next_to_visit.pop() visited.add(current_node) - if (current_node in reachable_by_input - and current_node not in input_nodes_set): - raise TypeError( - "Node %s uses input %s not in input_nodes." % (n, current_node)) + if (current_node in reachable_by_input and + current_node not in input_nodes_set): + raise TypeError("Node %s uses input %s not in input_nodes." % + (n, current_node)) if current_node not in input_nodes_set: next_to_visit += [ input_node for input_node in name_to_input_name[current_node] @@ -1066,6 +1072,7 @@ def _remove_one_redundant_stack_unstack(in_graph_def): Args: in_graph_def: Graph def to use as input. + Returns: Simplified tuple (graph_def, changed_something) where changed_something is true if anything was done. @@ -1101,15 +1108,15 @@ def _remove_one_redundant_stack_unstack(in_graph_def): node = name_to_node[current_node_name] is_op_hint_stack = node.name.startswith("OpHintStack") is_op_hint_unstack = node.name.startswith("OpHintUnstack") - if (node.op == "Identity" or is_op_hint_stack - or (do_generic_pack_unpack and node.op == "Pack")): + if (node.op == "Identity" or is_op_hint_stack or + (do_generic_pack_unpack and node.op == "Pack")): is_hint_created_stack |= is_op_hint_stack next_to_visit += [ input_node for input_node in name_to_input_name[current_node_name] if input_node not in visited ] - elif (is_op_hint_unstack - or (do_generic_pack_unpack and node.op == "Unpack")): + elif (is_op_hint_unstack or + (do_generic_pack_unpack and node.op == "Unpack")): unpack_nodes.add(node.name) is_hint_created_stack &= is_op_hint_unstack else: @@ -1124,7 +1131,8 @@ def _remove_one_redundant_stack_unstack(in_graph_def): # Unstacked form no_external_dependency = True for other_n in in_graph_def.node: - if other_n.name in visited: continue + if other_n.name in visited: + continue for input_tensor in name_to_input_name[other_n.name]: input_op = _tensor_name_base(input_tensor) if input_op in visited and input_op != pack_node: @@ -1141,9 +1149,9 @@ def _remove_one_redundant_stack_unstack(in_graph_def): if node_name not in visited: new_node = _copy.deepcopy(other_n) new_node.input[:] = [ - (end_input if stripped == pack_node else - non_stripped) for stripped, non_stripped in zip( - name_to_input_name[node_name], new_node.input[:]) + (end_input if stripped == pack_node else non_stripped) + for stripped, non_stripped in zip(name_to_input_name[node_name], + new_node.input[:]) ] out.node.extend([new_node]) return out, True @@ -1177,6 +1185,7 @@ def _convert_op_hints_to_stubs_helper( graph_def: A graph def that we should convert. write_callback: A function pointer that can be used to write intermediate steps of graph transformation (optional). + Returns: A new stubbed graph_def. """ @@ -1306,6 +1315,7 @@ def convert_op_hints_to_stubs(session=None, graph_def: A graph def that we should convert. write_callback: A function pointer that can be used to write intermediate steps of graph transformation (optional). + Returns: A new graphdef with all ops contained in OpHints being replaced by a single op call with the right parameters.