Fix style in `op_hint.py` to match formatting from Copybara.
No functional changes PiperOrigin-RevId: 311566454 Change-Id: Ic4f002df42168bdb8841b80a93ebf22a8e7fa4bd
This commit is contained in:
parent
8098b12009
commit
0d94bc6d71
|
@ -435,6 +435,7 @@ class OpHint(object):
|
|||
Args:
|
||||
*args: List of inputs to be converted (should be Tf.Tensor).
|
||||
**kwargs: This allows 'names' which should be a list of names.
|
||||
|
||||
Returns:
|
||||
Wrapped inputs (identity standins that have additional metadata). These
|
||||
are also are also tf.Tensor's.
|
||||
|
@ -453,6 +454,7 @@ class OpHint(object):
|
|||
Args:
|
||||
*args: List of outputs to be converted (should be tf.Tensor).
|
||||
**kwargs: See
|
||||
|
||||
Returns:
|
||||
Wrapped outputs (identity standins that have additional metadata). These
|
||||
are also tf.Tensor's.
|
||||
|
@ -574,8 +576,8 @@ class _LiteAggregateOperand(_LiteOperand):
|
|||
elif self.aggregation == OpHint.AGGREGATE_STACK:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid aggregation type %r specified" % self.aggregation)
|
||||
raise ValueError("Invalid aggregation type %r specified" %
|
||||
self.aggregation)
|
||||
return self.flattened
|
||||
|
||||
def flatten(self):
|
||||
|
@ -646,8 +648,8 @@ class _LiteAggregateOperand(_LiteOperand):
|
|||
stack_node.attr["num"].i = len(flattened)
|
||||
output_type = flattened[0].attr["T"].type
|
||||
stack_node.attr["T"].type = output_type
|
||||
stack_node.input.append(_tensorflow_output_name(
|
||||
fused_op_name, output_index))
|
||||
stack_node.input.append(
|
||||
_tensorflow_output_name(fused_op_name, output_index))
|
||||
out_graphdef.node.extend([stack_node])
|
||||
|
||||
for idx, discrete in enumerate(flattened):
|
||||
|
@ -675,11 +677,10 @@ class _LiteFuncCall(object):
|
|||
inputs: inputs to the op (hash from index # to argument)
|
||||
outputs: outputs to the op (hash from index # to argument)
|
||||
function_name: the tflite custom op name to use
|
||||
uuid: a unique call id for this particular call (i.e.
|
||||
multiple function calls would have the same function_name but different
|
||||
uuids.
|
||||
params: A param name to key value for op constant data. I.e. for
|
||||
axis on a reduction, strides on a convolution, etc.
|
||||
uuid: a unique call id for this particular call (i.e. multiple function
|
||||
calls would have the same function_name but different uuids.
|
||||
params: A param name to key value for op constant data. I.e. for axis on a
|
||||
reduction, strides on a convolution, etc.
|
||||
level: Level of the OpHint.
|
||||
children_inputs_mappings: If the Ophint has children, children inputs
|
||||
mappings indicate how their inputs & outputs are mapped.
|
||||
|
@ -700,6 +701,7 @@ class _LiteFuncCall(object):
|
|||
Returns:
|
||||
Tuple of (inputs, outputs). where input and output i a list of names.
|
||||
"""
|
||||
|
||||
def _flatten(input_or_output_dict):
|
||||
flattened_items = []
|
||||
for item in input_or_output_dict.values():
|
||||
|
@ -709,6 +711,7 @@ class _LiteFuncCall(object):
|
|||
return _flatten(self.inputs), _flatten(self.outputs)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
def format_args(items):
|
||||
s = ""
|
||||
for idx, item in items.iteritems():
|
||||
|
@ -739,8 +742,8 @@ def _find_all_hints_in_nodes(nodes):
|
|||
for node in nodes:
|
||||
attr = node.attr
|
||||
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
|
||||
if (OpHint.FUNCTION_UUID_ATTR not in attr
|
||||
or not attr[OpHint.FUNCTION_UUID_ATTR].s):
|
||||
if (OpHint.FUNCTION_UUID_ATTR not in attr or
|
||||
not attr[OpHint.FUNCTION_UUID_ATTR].s):
|
||||
continue
|
||||
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
|
||||
|
||||
|
@ -751,9 +754,11 @@ def _find_all_hints_in_nodes(nodes):
|
|||
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
|
||||
# Get sorting and aggregation information
|
||||
|
||||
sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
|
||||
sort = (
|
||||
attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
|
||||
if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
|
||||
if sort == -1: sort = None
|
||||
if sort == -1:
|
||||
sort = None
|
||||
aggregation = None
|
||||
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
|
||||
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
|
||||
|
@ -887,6 +892,7 @@ def _tensor_name_base(full_tensor_name):
|
|||
Args:
|
||||
full_tensor_name: A tensor name that is annotated with a device placement
|
||||
(this is what tensor flow introspection gives).
|
||||
|
||||
Returns:
|
||||
A name without any device assignment.
|
||||
"""
|
||||
|
@ -919,10 +925,10 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
|
|||
while next_to_visit:
|
||||
current_node = next_to_visit.pop()
|
||||
visited.add(current_node)
|
||||
if (current_node in reachable_by_input
|
||||
and current_node not in input_nodes_set):
|
||||
raise TypeError(
|
||||
"Node %s uses input %s not in input_nodes." % (n, current_node))
|
||||
if (current_node in reachable_by_input and
|
||||
current_node not in input_nodes_set):
|
||||
raise TypeError("Node %s uses input %s not in input_nodes." %
|
||||
(n, current_node))
|
||||
if current_node not in input_nodes_set:
|
||||
next_to_visit += [
|
||||
input_node for input_node in name_to_input_name[current_node]
|
||||
|
@ -1066,6 +1072,7 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
|||
|
||||
Args:
|
||||
in_graph_def: Graph def to use as input.
|
||||
|
||||
Returns:
|
||||
Simplified tuple (graph_def, changed_something) where changed_something
|
||||
is true if anything was done.
|
||||
|
@ -1101,15 +1108,15 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
|||
node = name_to_node[current_node_name]
|
||||
is_op_hint_stack = node.name.startswith("OpHintStack")
|
||||
is_op_hint_unstack = node.name.startswith("OpHintUnstack")
|
||||
if (node.op == "Identity" or is_op_hint_stack
|
||||
or (do_generic_pack_unpack and node.op == "Pack")):
|
||||
if (node.op == "Identity" or is_op_hint_stack or
|
||||
(do_generic_pack_unpack and node.op == "Pack")):
|
||||
is_hint_created_stack |= is_op_hint_stack
|
||||
next_to_visit += [
|
||||
input_node for input_node in name_to_input_name[current_node_name]
|
||||
if input_node not in visited
|
||||
]
|
||||
elif (is_op_hint_unstack
|
||||
or (do_generic_pack_unpack and node.op == "Unpack")):
|
||||
elif (is_op_hint_unstack or
|
||||
(do_generic_pack_unpack and node.op == "Unpack")):
|
||||
unpack_nodes.add(node.name)
|
||||
is_hint_created_stack &= is_op_hint_unstack
|
||||
else:
|
||||
|
@ -1124,7 +1131,8 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
|||
# Unstacked form
|
||||
no_external_dependency = True
|
||||
for other_n in in_graph_def.node:
|
||||
if other_n.name in visited: continue
|
||||
if other_n.name in visited:
|
||||
continue
|
||||
for input_tensor in name_to_input_name[other_n.name]:
|
||||
input_op = _tensor_name_base(input_tensor)
|
||||
if input_op in visited and input_op != pack_node:
|
||||
|
@ -1141,9 +1149,9 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
|||
if node_name not in visited:
|
||||
new_node = _copy.deepcopy(other_n)
|
||||
new_node.input[:] = [
|
||||
(end_input if stripped == pack_node else
|
||||
non_stripped) for stripped, non_stripped in zip(
|
||||
name_to_input_name[node_name], new_node.input[:])
|
||||
(end_input if stripped == pack_node else non_stripped)
|
||||
for stripped, non_stripped in zip(name_to_input_name[node_name],
|
||||
new_node.input[:])
|
||||
]
|
||||
out.node.extend([new_node])
|
||||
return out, True
|
||||
|
@ -1177,6 +1185,7 @@ def _convert_op_hints_to_stubs_helper(
|
|||
graph_def: A graph def that we should convert.
|
||||
write_callback: A function pointer that can be used to write intermediate
|
||||
steps of graph transformation (optional).
|
||||
|
||||
Returns:
|
||||
A new stubbed graph_def.
|
||||
"""
|
||||
|
@ -1306,6 +1315,7 @@ def convert_op_hints_to_stubs(session=None,
|
|||
graph_def: A graph def that we should convert.
|
||||
write_callback: A function pointer that can be used to write intermediate
|
||||
steps of graph transformation (optional).
|
||||
|
||||
Returns:
|
||||
A new graphdef with all ops contained in OpHints being replaced by
|
||||
a single op call with the right parameters.
|
||||
|
|
Loading…
Reference in New Issue