Fix style in `op_hint.py` to match formatting from Copybara.

No functional changes

PiperOrigin-RevId: 311566454
Change-Id: Ic4f002df42168bdb8841b80a93ebf22a8e7fa4bd
This commit is contained in:
Mihai Maruseac 2020-05-14 11:02:42 -07:00 committed by TensorFlower Gardener
parent 8098b12009
commit 0d94bc6d71
1 changed files with 36 additions and 26 deletions

View File

@ -435,6 +435,7 @@ class OpHint(object):
Args:
*args: List of inputs to be converted (should be Tf.Tensor).
**kwargs: This allows 'names' which should be a list of names.
Returns:
Wrapped inputs (identity standins that have additional metadata). These
are also are also tf.Tensor's.
@ -453,6 +454,7 @@ class OpHint(object):
Args:
*args: List of outputs to be converted (should be tf.Tensor).
**kwargs: See
Returns:
Wrapped outputs (identity standins that have additional metadata). These
are also tf.Tensor's.
@ -574,8 +576,8 @@ class _LiteAggregateOperand(_LiteOperand):
elif self.aggregation == OpHint.AGGREGATE_STACK:
pass
else:
raise ValueError(
"Invalid aggregation type %r specified" % self.aggregation)
raise ValueError("Invalid aggregation type %r specified" %
self.aggregation)
return self.flattened
def flatten(self):
@ -646,8 +648,8 @@ class _LiteAggregateOperand(_LiteOperand):
stack_node.attr["num"].i = len(flattened)
output_type = flattened[0].attr["T"].type
stack_node.attr["T"].type = output_type
stack_node.input.append(_tensorflow_output_name(
fused_op_name, output_index))
stack_node.input.append(
_tensorflow_output_name(fused_op_name, output_index))
out_graphdef.node.extend([stack_node])
for idx, discrete in enumerate(flattened):
@ -675,11 +677,10 @@ class _LiteFuncCall(object):
inputs: inputs to the op (hash from index # to argument)
outputs: outputs to the op (hash from index # to argument)
function_name: the tflite custom op name to use
uuid: a unique call id for this particular call (i.e.
multiple function calls would have the same function_name but different
uuids.
params: A param name to key value for op constant data. I.e. for
axis on a reduction, strides on a convolution, etc.
uuid: a unique call id for this particular call (i.e. multiple function
calls would have the same function_name but different uuids.
params: A param name to key value for op constant data. I.e. for axis on a
reduction, strides on a convolution, etc.
level: Level of the OpHint.
children_inputs_mappings: If the Ophint has children, children inputs
mappings indicate how their inputs & outputs are mapped.
@ -700,6 +701,7 @@ class _LiteFuncCall(object):
Returns:
Tuple of (inputs, outputs). where input and output i a list of names.
"""
def _flatten(input_or_output_dict):
flattened_items = []
for item in input_or_output_dict.values():
@ -709,6 +711,7 @@ class _LiteFuncCall(object):
return _flatten(self.inputs), _flatten(self.outputs)
def __str__(self):
def format_args(items):
s = ""
for idx, item in items.iteritems():
@ -739,8 +742,8 @@ def _find_all_hints_in_nodes(nodes):
for node in nodes:
attr = node.attr
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
if (OpHint.FUNCTION_UUID_ATTR not in attr
or not attr[OpHint.FUNCTION_UUID_ATTR].s):
if (OpHint.FUNCTION_UUID_ATTR not in attr or
not attr[OpHint.FUNCTION_UUID_ATTR].s):
continue
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
@ -751,9 +754,11 @@ def _find_all_hints_in_nodes(nodes):
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
# Get sorting and aggregation information
sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
sort = (
attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
if sort == -1: sort = None
if sort == -1:
sort = None
aggregation = None
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
@ -887,6 +892,7 @@ def _tensor_name_base(full_tensor_name):
Args:
full_tensor_name: A tensor name that is annotated with a device placement
(this is what tensor flow introspection gives).
Returns:
A name without any device assignment.
"""
@ -919,10 +925,10 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
while next_to_visit:
current_node = next_to_visit.pop()
visited.add(current_node)
if (current_node in reachable_by_input
and current_node not in input_nodes_set):
raise TypeError(
"Node %s uses input %s not in input_nodes." % (n, current_node))
if (current_node in reachable_by_input and
current_node not in input_nodes_set):
raise TypeError("Node %s uses input %s not in input_nodes." %
(n, current_node))
if current_node not in input_nodes_set:
next_to_visit += [
input_node for input_node in name_to_input_name[current_node]
@ -1066,6 +1072,7 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
Args:
in_graph_def: Graph def to use as input.
Returns:
Simplified tuple (graph_def, changed_something) where changed_something
is true if anything was done.
@ -1101,15 +1108,15 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
node = name_to_node[current_node_name]
is_op_hint_stack = node.name.startswith("OpHintStack")
is_op_hint_unstack = node.name.startswith("OpHintUnstack")
if (node.op == "Identity" or is_op_hint_stack
or (do_generic_pack_unpack and node.op == "Pack")):
if (node.op == "Identity" or is_op_hint_stack or
(do_generic_pack_unpack and node.op == "Pack")):
is_hint_created_stack |= is_op_hint_stack
next_to_visit += [
input_node for input_node in name_to_input_name[current_node_name]
if input_node not in visited
]
elif (is_op_hint_unstack
or (do_generic_pack_unpack and node.op == "Unpack")):
elif (is_op_hint_unstack or
(do_generic_pack_unpack and node.op == "Unpack")):
unpack_nodes.add(node.name)
is_hint_created_stack &= is_op_hint_unstack
else:
@ -1124,7 +1131,8 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
# Unstacked form
no_external_dependency = True
for other_n in in_graph_def.node:
if other_n.name in visited: continue
if other_n.name in visited:
continue
for input_tensor in name_to_input_name[other_n.name]:
input_op = _tensor_name_base(input_tensor)
if input_op in visited and input_op != pack_node:
@ -1141,9 +1149,9 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
if node_name not in visited:
new_node = _copy.deepcopy(other_n)
new_node.input[:] = [
(end_input if stripped == pack_node else
non_stripped) for stripped, non_stripped in zip(
name_to_input_name[node_name], new_node.input[:])
(end_input if stripped == pack_node else non_stripped)
for stripped, non_stripped in zip(name_to_input_name[node_name],
new_node.input[:])
]
out.node.extend([new_node])
return out, True
@ -1177,6 +1185,7 @@ def _convert_op_hints_to_stubs_helper(
graph_def: A graph def that we should convert.
write_callback: A function pointer that can be used to write intermediate
steps of graph transformation (optional).
Returns:
A new stubbed graph_def.
"""
@ -1306,6 +1315,7 @@ def convert_op_hints_to_stubs(session=None,
graph_def: A graph def that we should convert.
write_callback: A function pointer that can be used to write intermediate
steps of graph transformation (optional).
Returns:
A new graphdef with all ops contained in OpHints being replaced by
a single op call with the right parameters.