Fix style in `op_hint.py` to match formatting from Copybara.
No functional changes PiperOrigin-RevId: 311566454 Change-Id: Ic4f002df42168bdb8841b80a93ebf22a8e7fa4bd
This commit is contained in:
parent
8098b12009
commit
0d94bc6d71
|
@ -435,6 +435,7 @@ class OpHint(object):
|
||||||
Args:
|
Args:
|
||||||
*args: List of inputs to be converted (should be Tf.Tensor).
|
*args: List of inputs to be converted (should be Tf.Tensor).
|
||||||
**kwargs: This allows 'names' which should be a list of names.
|
**kwargs: This allows 'names' which should be a list of names.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Wrapped inputs (identity standins that have additional metadata). These
|
Wrapped inputs (identity standins that have additional metadata). These
|
||||||
are also are also tf.Tensor's.
|
are also are also tf.Tensor's.
|
||||||
|
@ -453,6 +454,7 @@ class OpHint(object):
|
||||||
Args:
|
Args:
|
||||||
*args: List of outputs to be converted (should be tf.Tensor).
|
*args: List of outputs to be converted (should be tf.Tensor).
|
||||||
**kwargs: See
|
**kwargs: See
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Wrapped outputs (identity standins that have additional metadata). These
|
Wrapped outputs (identity standins that have additional metadata). These
|
||||||
are also tf.Tensor's.
|
are also tf.Tensor's.
|
||||||
|
@ -574,8 +576,8 @@ class _LiteAggregateOperand(_LiteOperand):
|
||||||
elif self.aggregation == OpHint.AGGREGATE_STACK:
|
elif self.aggregation == OpHint.AGGREGATE_STACK:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError("Invalid aggregation type %r specified" %
|
||||||
"Invalid aggregation type %r specified" % self.aggregation)
|
self.aggregation)
|
||||||
return self.flattened
|
return self.flattened
|
||||||
|
|
||||||
def flatten(self):
|
def flatten(self):
|
||||||
|
@ -646,8 +648,8 @@ class _LiteAggregateOperand(_LiteOperand):
|
||||||
stack_node.attr["num"].i = len(flattened)
|
stack_node.attr["num"].i = len(flattened)
|
||||||
output_type = flattened[0].attr["T"].type
|
output_type = flattened[0].attr["T"].type
|
||||||
stack_node.attr["T"].type = output_type
|
stack_node.attr["T"].type = output_type
|
||||||
stack_node.input.append(_tensorflow_output_name(
|
stack_node.input.append(
|
||||||
fused_op_name, output_index))
|
_tensorflow_output_name(fused_op_name, output_index))
|
||||||
out_graphdef.node.extend([stack_node])
|
out_graphdef.node.extend([stack_node])
|
||||||
|
|
||||||
for idx, discrete in enumerate(flattened):
|
for idx, discrete in enumerate(flattened):
|
||||||
|
@ -675,11 +677,10 @@ class _LiteFuncCall(object):
|
||||||
inputs: inputs to the op (hash from index # to argument)
|
inputs: inputs to the op (hash from index # to argument)
|
||||||
outputs: outputs to the op (hash from index # to argument)
|
outputs: outputs to the op (hash from index # to argument)
|
||||||
function_name: the tflite custom op name to use
|
function_name: the tflite custom op name to use
|
||||||
uuid: a unique call id for this particular call (i.e.
|
uuid: a unique call id for this particular call (i.e. multiple function
|
||||||
multiple function calls would have the same function_name but different
|
calls would have the same function_name but different uuids.
|
||||||
uuids.
|
params: A param name to key value for op constant data. I.e. for axis on a
|
||||||
params: A param name to key value for op constant data. I.e. for
|
reduction, strides on a convolution, etc.
|
||||||
axis on a reduction, strides on a convolution, etc.
|
|
||||||
level: Level of the OpHint.
|
level: Level of the OpHint.
|
||||||
children_inputs_mappings: If the Ophint has children, children inputs
|
children_inputs_mappings: If the Ophint has children, children inputs
|
||||||
mappings indicate how their inputs & outputs are mapped.
|
mappings indicate how their inputs & outputs are mapped.
|
||||||
|
@ -700,6 +701,7 @@ class _LiteFuncCall(object):
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (inputs, outputs). where input and output i a list of names.
|
Tuple of (inputs, outputs). where input and output i a list of names.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _flatten(input_or_output_dict):
|
def _flatten(input_or_output_dict):
|
||||||
flattened_items = []
|
flattened_items = []
|
||||||
for item in input_or_output_dict.values():
|
for item in input_or_output_dict.values():
|
||||||
|
@ -709,6 +711,7 @@ class _LiteFuncCall(object):
|
||||||
return _flatten(self.inputs), _flatten(self.outputs)
|
return _flatten(self.inputs), _flatten(self.outputs)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|
||||||
def format_args(items):
|
def format_args(items):
|
||||||
s = ""
|
s = ""
|
||||||
for idx, item in items.iteritems():
|
for idx, item in items.iteritems():
|
||||||
|
@ -739,8 +742,8 @@ def _find_all_hints_in_nodes(nodes):
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
attr = node.attr
|
attr = node.attr
|
||||||
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
|
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
|
||||||
if (OpHint.FUNCTION_UUID_ATTR not in attr
|
if (OpHint.FUNCTION_UUID_ATTR not in attr or
|
||||||
or not attr[OpHint.FUNCTION_UUID_ATTR].s):
|
not attr[OpHint.FUNCTION_UUID_ATTR].s):
|
||||||
continue
|
continue
|
||||||
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
|
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
|
||||||
|
|
||||||
|
@ -751,9 +754,11 @@ def _find_all_hints_in_nodes(nodes):
|
||||||
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
|
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
|
||||||
# Get sorting and aggregation information
|
# Get sorting and aggregation information
|
||||||
|
|
||||||
sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
|
sort = (
|
||||||
if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
|
attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
|
||||||
if sort == -1: sort = None
|
if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
|
||||||
|
if sort == -1:
|
||||||
|
sort = None
|
||||||
aggregation = None
|
aggregation = None
|
||||||
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
|
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
|
||||||
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
|
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
|
||||||
|
@ -887,6 +892,7 @@ def _tensor_name_base(full_tensor_name):
|
||||||
Args:
|
Args:
|
||||||
full_tensor_name: A tensor name that is annotated with a device placement
|
full_tensor_name: A tensor name that is annotated with a device placement
|
||||||
(this is what tensor flow introspection gives).
|
(this is what tensor flow introspection gives).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A name without any device assignment.
|
A name without any device assignment.
|
||||||
"""
|
"""
|
||||||
|
@ -919,10 +925,10 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
|
||||||
while next_to_visit:
|
while next_to_visit:
|
||||||
current_node = next_to_visit.pop()
|
current_node = next_to_visit.pop()
|
||||||
visited.add(current_node)
|
visited.add(current_node)
|
||||||
if (current_node in reachable_by_input
|
if (current_node in reachable_by_input and
|
||||||
and current_node not in input_nodes_set):
|
current_node not in input_nodes_set):
|
||||||
raise TypeError(
|
raise TypeError("Node %s uses input %s not in input_nodes." %
|
||||||
"Node %s uses input %s not in input_nodes." % (n, current_node))
|
(n, current_node))
|
||||||
if current_node not in input_nodes_set:
|
if current_node not in input_nodes_set:
|
||||||
next_to_visit += [
|
next_to_visit += [
|
||||||
input_node for input_node in name_to_input_name[current_node]
|
input_node for input_node in name_to_input_name[current_node]
|
||||||
|
@ -1066,6 +1072,7 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
in_graph_def: Graph def to use as input.
|
in_graph_def: Graph def to use as input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Simplified tuple (graph_def, changed_something) where changed_something
|
Simplified tuple (graph_def, changed_something) where changed_something
|
||||||
is true if anything was done.
|
is true if anything was done.
|
||||||
|
@ -1101,15 +1108,15 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
||||||
node = name_to_node[current_node_name]
|
node = name_to_node[current_node_name]
|
||||||
is_op_hint_stack = node.name.startswith("OpHintStack")
|
is_op_hint_stack = node.name.startswith("OpHintStack")
|
||||||
is_op_hint_unstack = node.name.startswith("OpHintUnstack")
|
is_op_hint_unstack = node.name.startswith("OpHintUnstack")
|
||||||
if (node.op == "Identity" or is_op_hint_stack
|
if (node.op == "Identity" or is_op_hint_stack or
|
||||||
or (do_generic_pack_unpack and node.op == "Pack")):
|
(do_generic_pack_unpack and node.op == "Pack")):
|
||||||
is_hint_created_stack |= is_op_hint_stack
|
is_hint_created_stack |= is_op_hint_stack
|
||||||
next_to_visit += [
|
next_to_visit += [
|
||||||
input_node for input_node in name_to_input_name[current_node_name]
|
input_node for input_node in name_to_input_name[current_node_name]
|
||||||
if input_node not in visited
|
if input_node not in visited
|
||||||
]
|
]
|
||||||
elif (is_op_hint_unstack
|
elif (is_op_hint_unstack or
|
||||||
or (do_generic_pack_unpack and node.op == "Unpack")):
|
(do_generic_pack_unpack and node.op == "Unpack")):
|
||||||
unpack_nodes.add(node.name)
|
unpack_nodes.add(node.name)
|
||||||
is_hint_created_stack &= is_op_hint_unstack
|
is_hint_created_stack &= is_op_hint_unstack
|
||||||
else:
|
else:
|
||||||
|
@ -1124,7 +1131,8 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
||||||
# Unstacked form
|
# Unstacked form
|
||||||
no_external_dependency = True
|
no_external_dependency = True
|
||||||
for other_n in in_graph_def.node:
|
for other_n in in_graph_def.node:
|
||||||
if other_n.name in visited: continue
|
if other_n.name in visited:
|
||||||
|
continue
|
||||||
for input_tensor in name_to_input_name[other_n.name]:
|
for input_tensor in name_to_input_name[other_n.name]:
|
||||||
input_op = _tensor_name_base(input_tensor)
|
input_op = _tensor_name_base(input_tensor)
|
||||||
if input_op in visited and input_op != pack_node:
|
if input_op in visited and input_op != pack_node:
|
||||||
|
@ -1141,9 +1149,9 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
|
||||||
if node_name not in visited:
|
if node_name not in visited:
|
||||||
new_node = _copy.deepcopy(other_n)
|
new_node = _copy.deepcopy(other_n)
|
||||||
new_node.input[:] = [
|
new_node.input[:] = [
|
||||||
(end_input if stripped == pack_node else
|
(end_input if stripped == pack_node else non_stripped)
|
||||||
non_stripped) for stripped, non_stripped in zip(
|
for stripped, non_stripped in zip(name_to_input_name[node_name],
|
||||||
name_to_input_name[node_name], new_node.input[:])
|
new_node.input[:])
|
||||||
]
|
]
|
||||||
out.node.extend([new_node])
|
out.node.extend([new_node])
|
||||||
return out, True
|
return out, True
|
||||||
|
@ -1177,6 +1185,7 @@ def _convert_op_hints_to_stubs_helper(
|
||||||
graph_def: A graph def that we should convert.
|
graph_def: A graph def that we should convert.
|
||||||
write_callback: A function pointer that can be used to write intermediate
|
write_callback: A function pointer that can be used to write intermediate
|
||||||
steps of graph transformation (optional).
|
steps of graph transformation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new stubbed graph_def.
|
A new stubbed graph_def.
|
||||||
"""
|
"""
|
||||||
|
@ -1306,6 +1315,7 @@ def convert_op_hints_to_stubs(session=None,
|
||||||
graph_def: A graph def that we should convert.
|
graph_def: A graph def that we should convert.
|
||||||
write_callback: A function pointer that can be used to write intermediate
|
write_callback: A function pointer that can be used to write intermediate
|
||||||
steps of graph transformation (optional).
|
steps of graph transformation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new graphdef with all ops contained in OpHints being replaced by
|
A new graphdef with all ops contained in OpHints being replaced by
|
||||||
a single op call with the right parameters.
|
a single op call with the right parameters.
|
||||||
|
|
Loading…
Reference in New Issue