Fix style in `op_hint.py` to match formatting from Copybara.

No functional changes

PiperOrigin-RevId: 311566454
Change-Id: Ic4f002df42168bdb8841b80a93ebf22a8e7fa4bd
This commit is contained in:
Mihai Maruseac 2020-05-14 11:02:42 -07:00 committed by TensorFlower Gardener
parent 8098b12009
commit 0d94bc6d71
1 changed files with 36 additions and 26 deletions

View File

@ -435,6 +435,7 @@ class OpHint(object):
Args: Args:
*args: List of inputs to be converted (should be Tf.Tensor). *args: List of inputs to be converted (should be Tf.Tensor).
**kwargs: This allows 'names' which should be a list of names. **kwargs: This allows 'names' which should be a list of names.
Returns: Returns:
Wrapped inputs (identity standins that have additional metadata). These Wrapped inputs (identity standins that have additional metadata). These
are also are also tf.Tensor's. are also are also tf.Tensor's.
@ -453,6 +454,7 @@ class OpHint(object):
Args: Args:
*args: List of outputs to be converted (should be tf.Tensor). *args: List of outputs to be converted (should be tf.Tensor).
**kwargs: See **kwargs: See
Returns: Returns:
Wrapped outputs (identity standins that have additional metadata). These Wrapped outputs (identity standins that have additional metadata). These
are also tf.Tensor's. are also tf.Tensor's.
@ -574,8 +576,8 @@ class _LiteAggregateOperand(_LiteOperand):
elif self.aggregation == OpHint.AGGREGATE_STACK: elif self.aggregation == OpHint.AGGREGATE_STACK:
pass pass
else: else:
raise ValueError( raise ValueError("Invalid aggregation type %r specified" %
"Invalid aggregation type %r specified" % self.aggregation) self.aggregation)
return self.flattened return self.flattened
def flatten(self): def flatten(self):
@ -646,8 +648,8 @@ class _LiteAggregateOperand(_LiteOperand):
stack_node.attr["num"].i = len(flattened) stack_node.attr["num"].i = len(flattened)
output_type = flattened[0].attr["T"].type output_type = flattened[0].attr["T"].type
stack_node.attr["T"].type = output_type stack_node.attr["T"].type = output_type
stack_node.input.append(_tensorflow_output_name( stack_node.input.append(
fused_op_name, output_index)) _tensorflow_output_name(fused_op_name, output_index))
out_graphdef.node.extend([stack_node]) out_graphdef.node.extend([stack_node])
for idx, discrete in enumerate(flattened): for idx, discrete in enumerate(flattened):
@ -675,11 +677,10 @@ class _LiteFuncCall(object):
inputs: inputs to the op (hash from index # to argument) inputs: inputs to the op (hash from index # to argument)
outputs: outputs to the op (hash from index # to argument) outputs: outputs to the op (hash from index # to argument)
function_name: the tflite custom op name to use function_name: the tflite custom op name to use
uuid: a unique call id for this particular call (i.e. uuid: a unique call id for this particular call (i.e. multiple function
multiple function calls would have the same function_name but different calls would have the same function_name but different uuids.
uuids. params: A param name to key value for op constant data. I.e. for axis on a
params: A param name to key value for op constant data. I.e. for reduction, strides on a convolution, etc.
axis on a reduction, strides on a convolution, etc.
level: Level of the OpHint. level: Level of the OpHint.
children_inputs_mappings: If the Ophint has children, children inputs children_inputs_mappings: If the Ophint has children, children inputs
mappings indicate how their inputs & outputs are mapped. mappings indicate how their inputs & outputs are mapped.
@ -700,6 +701,7 @@ class _LiteFuncCall(object):
Returns: Returns:
Tuple of (inputs, outputs). where input and output i a list of names. Tuple of (inputs, outputs). where input and output i a list of names.
""" """
def _flatten(input_or_output_dict): def _flatten(input_or_output_dict):
flattened_items = [] flattened_items = []
for item in input_or_output_dict.values(): for item in input_or_output_dict.values():
@ -709,6 +711,7 @@ class _LiteFuncCall(object):
return _flatten(self.inputs), _flatten(self.outputs) return _flatten(self.inputs), _flatten(self.outputs)
def __str__(self): def __str__(self):
def format_args(items): def format_args(items):
s = "" s = ""
for idx, item in items.iteritems(): for idx, item in items.iteritems():
@ -739,8 +742,8 @@ def _find_all_hints_in_nodes(nodes):
for node in nodes: for node in nodes:
attr = node.attr attr = node.attr
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
if (OpHint.FUNCTION_UUID_ATTR not in attr if (OpHint.FUNCTION_UUID_ATTR not in attr or
or not attr[OpHint.FUNCTION_UUID_ATTR].s): not attr[OpHint.FUNCTION_UUID_ATTR].s):
continue continue
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
@ -751,9 +754,11 @@ def _find_all_hints_in_nodes(nodes):
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
# Get sorting and aggregation information # Get sorting and aggregation information
sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i sort = (
if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None) attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
if sort == -1: sort = None if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
if sort == -1:
sort = None
aggregation = None aggregation = None
if OpHint.FUNCTION_AGGREGATE_ATTR in attr: if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s) aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
@ -887,6 +892,7 @@ def _tensor_name_base(full_tensor_name):
Args: Args:
full_tensor_name: A tensor name that is annotated with a device placement full_tensor_name: A tensor name that is annotated with a device placement
(this is what tensor flow introspection gives). (this is what tensor flow introspection gives).
Returns: Returns:
A name without any device assignment. A name without any device assignment.
""" """
@ -919,10 +925,10 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
while next_to_visit: while next_to_visit:
current_node = next_to_visit.pop() current_node = next_to_visit.pop()
visited.add(current_node) visited.add(current_node)
if (current_node in reachable_by_input if (current_node in reachable_by_input and
and current_node not in input_nodes_set): current_node not in input_nodes_set):
raise TypeError( raise TypeError("Node %s uses input %s not in input_nodes." %
"Node %s uses input %s not in input_nodes." % (n, current_node)) (n, current_node))
if current_node not in input_nodes_set: if current_node not in input_nodes_set:
next_to_visit += [ next_to_visit += [
input_node for input_node in name_to_input_name[current_node] input_node for input_node in name_to_input_name[current_node]
@ -1066,6 +1072,7 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
Args: Args:
in_graph_def: Graph def to use as input. in_graph_def: Graph def to use as input.
Returns: Returns:
Simplified tuple (graph_def, changed_something) where changed_something Simplified tuple (graph_def, changed_something) where changed_something
is true if anything was done. is true if anything was done.
@ -1101,15 +1108,15 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
node = name_to_node[current_node_name] node = name_to_node[current_node_name]
is_op_hint_stack = node.name.startswith("OpHintStack") is_op_hint_stack = node.name.startswith("OpHintStack")
is_op_hint_unstack = node.name.startswith("OpHintUnstack") is_op_hint_unstack = node.name.startswith("OpHintUnstack")
if (node.op == "Identity" or is_op_hint_stack if (node.op == "Identity" or is_op_hint_stack or
or (do_generic_pack_unpack and node.op == "Pack")): (do_generic_pack_unpack and node.op == "Pack")):
is_hint_created_stack |= is_op_hint_stack is_hint_created_stack |= is_op_hint_stack
next_to_visit += [ next_to_visit += [
input_node for input_node in name_to_input_name[current_node_name] input_node for input_node in name_to_input_name[current_node_name]
if input_node not in visited if input_node not in visited
] ]
elif (is_op_hint_unstack elif (is_op_hint_unstack or
or (do_generic_pack_unpack and node.op == "Unpack")): (do_generic_pack_unpack and node.op == "Unpack")):
unpack_nodes.add(node.name) unpack_nodes.add(node.name)
is_hint_created_stack &= is_op_hint_unstack is_hint_created_stack &= is_op_hint_unstack
else: else:
@ -1124,7 +1131,8 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
# Unstacked form # Unstacked form
no_external_dependency = True no_external_dependency = True
for other_n in in_graph_def.node: for other_n in in_graph_def.node:
if other_n.name in visited: continue if other_n.name in visited:
continue
for input_tensor in name_to_input_name[other_n.name]: for input_tensor in name_to_input_name[other_n.name]:
input_op = _tensor_name_base(input_tensor) input_op = _tensor_name_base(input_tensor)
if input_op in visited and input_op != pack_node: if input_op in visited and input_op != pack_node:
@ -1141,9 +1149,9 @@ def _remove_one_redundant_stack_unstack(in_graph_def):
if node_name not in visited: if node_name not in visited:
new_node = _copy.deepcopy(other_n) new_node = _copy.deepcopy(other_n)
new_node.input[:] = [ new_node.input[:] = [
(end_input if stripped == pack_node else (end_input if stripped == pack_node else non_stripped)
non_stripped) for stripped, non_stripped in zip( for stripped, non_stripped in zip(name_to_input_name[node_name],
name_to_input_name[node_name], new_node.input[:]) new_node.input[:])
] ]
out.node.extend([new_node]) out.node.extend([new_node])
return out, True return out, True
@ -1177,6 +1185,7 @@ def _convert_op_hints_to_stubs_helper(
graph_def: A graph def that we should convert. graph_def: A graph def that we should convert.
write_callback: A function pointer that can be used to write intermediate write_callback: A function pointer that can be used to write intermediate
steps of graph transformation (optional). steps of graph transformation (optional).
Returns: Returns:
A new stubbed graph_def. A new stubbed graph_def.
""" """
@ -1306,6 +1315,7 @@ def convert_op_hints_to_stubs(session=None,
graph_def: A graph def that we should convert. graph_def: A graph def that we should convert.
write_callback: A function pointer that can be used to write intermediate write_callback: A function pointer that can be used to write intermediate
steps of graph transformation (optional). steps of graph transformation (optional).
Returns: Returns:
A new graphdef with all ops contained in OpHints being replaced by A new graphdef with all ops contained in OpHints being replaced by
a single op call with the right parameters. a single op call with the right parameters.