OpHints supports dynamic_rnn, currently only add support for time_major=True case.

PiperOrigin-RevId: 231709454
This commit is contained in:
A. Unique TensorFlower 2019-01-30 18:38:29 -08:00 committed by TensorFlower Gardener
parent 7d8bfd88cd
commit 82ede0271e
9 changed files with 580 additions and 73 deletions

View File

@ -44,7 +44,7 @@ py_test(
"//tensorflow:tensorflow_py",
"//tensorflow/examples/tutorials/mnist:input_data",
"//tensorflow/lite/python:lite",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:framework",
"//tensorflow/python:platform",
"//tensorflow/python/tools:optimize_for_inference",
"//third_party/py/numpy",

View File

@ -22,17 +22,27 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.lite.python import lite
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn import _best_effort_input_batch_size
from tensorflow.python.ops.rnn import _dynamic_rnn_loop
from tensorflow.python.ops.rnn import _should_cache
from tensorflow.python.ops.rnn import _transpose_batch_time
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
@ -394,3 +404,240 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
}
base_config = super(TFLiteLSTMCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def dynamic_rnn(cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=True,
scope=None):
"""Creates a recurrent neural network specified by RNNCell `cell`.
Performs fully dynamic unrolling of `inputs`.
Example:
```python
# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
initial_state=initial_state,
dtype=tf.float32)
```
```python
# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
# create a RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is a N-tuple where N is the number of LSTMCells containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
inputs=data,
dtype=tf.float32)
```
Args:
cell: An instance of RNNCell.
inputs: The RNN inputs.
If `time_major == False` (default), this must be a `Tensor` of shape:
`[batch_size, max_time, ...]`, or a nested tuple of such elements.
If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
batch_size, ...]`, or a nested tuple of such elements. This may also be
a (possibly nested) tuple of Tensors satisfying this property. The
first two dimensions must match across all the inputs, but otherwise the
ranks and other shape components may differ. In this case, input to
`cell` at each time-step will replicate the structure of these tuples,
except for the time dimension (from which the time is taken). The input
to `cell` at each time step will be a `Tensor` or (possibly nested)
tuple of Tensors each with dimensions `[batch_size, ...]`.
sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
to copy-through state and zero-out outputs when past a batch element's
sequence length. So it's more for performance than correctness.
initial_state: (optional) An initial state for the RNN. If `cell.state_size`
is an integer, this must be a `Tensor` of appropriate type and shape
`[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
should be a tuple of tensors having shapes `[batch_size, s] for s in
cell.state_size`.
dtype: (optional) The data type for the initial state and expected output.
Required if initial_state is not provided or RNN state has a heterogeneous
dtype.
parallel_iterations: (Default: 32). The number of iterations to run in
parallel. Those operations which do not have any temporal dependency and
can be run in parallel, will be. This parameter trades off time for
space. Values >> 1 use more memory but take less time, while smaller
values use less memory but computations take longer.
swap_memory: Transparently swap the tensors produced in forward inference
but needed for back prop from GPU to CPU. This allows training RNNs which
would typically not fit on a single GPU, with very minimal (or no)
performance penalty.
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
`time_major = True` is a bit more efficient because it avoids transposes
at the beginning and end of the RNN calculation. However, most TensorFlow
data is batch-major, so by default this function accepts input and emits
output in batch-major form.
scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A pair (outputs, state) where:
outputs: The RNN output `Tensor`.
If time_major == False (default), this will be a `Tensor` shaped:
`[batch_size, max_time, cell.output_size]`.
If time_major == True, this will be a `Tensor` shaped:
`[max_time, batch_size, cell.output_size]`.
Note, if `cell.output_size` is a (possibly nested) tuple of integers
or `TensorShape` objects, then `outputs` will be a tuple having the
same structure as `cell.output_size`, containing Tensors having shapes
corresponding to the shape data in `cell.output_size`.
state: The final state. If `cell.state_size` is an int, this
will be shaped `[batch_size, cell.state_size]`. If it is a
`TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
If it is a (possibly nested) tuple of ints or `TensorShape`, this will
be a tuple having the corresponding shapes. If cells are `LSTMCells`
`state` will be a tuple containing a `LSTMStateTuple` for each cell.
Raises:
TypeError: If `cell` is not an instance of RNNCell.
ValueError: If inputs is None or an empty list.
RuntimeError: If not using control flow v2.
"""
# Currently only support time_major == True case.
assert time_major
# TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
# TfLiteRNNCells.
rnn_cell_impl.assert_like_rnncell("cell", cell)
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")
parent_first_child_input = [{
"parent_ophint_input_index": 0,
"first_child_ophint_input_index": 0
}]
parent_last_child_output = [{
"parent_output_index": 0,
# For LstmCell, the index is 2.
# For RnnCell, the index is 1.
# So we use -1 meaning it's the last one.
"child_output_index": -1
}]
internal_children_input_output = [{
"child_input_index": 0,
# For LstmCell, the index is 2.
# For RnnCell, the index is 1.
# So we use -1 meaning it's the last one.
"child_output_index": -1
}]
inputs_outputs_mappings = {
"parent_first_child_input": parent_first_child_input,
"parent_last_child_output": parent_last_child_output,
"internal_children_input_output": internal_children_input_output
}
tflite_wrapper = lite.OpHint(
"TfLiteDynamicRnn",
level=2,
children_inputs_mappings=inputs_outputs_mappings)
with vs.variable_scope(scope or "rnn") as varscope:
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)
# By default, time_major==False and inputs are batch-major: shaped
# [batch, time, depth]
# For internal calculations, we transpose to [time, batch, depth]
flat_input = nest.flatten(inputs)
if not time_major:
# (batch, time, depth) => (time, batch, depth)
flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
parallel_iterations = parallel_iterations or 32
if sequence_length is not None:
sequence_length = math_ops.to_int32(sequence_length)
if sequence_length.get_shape().rank not in (None, 1):
raise ValueError(
"sequence_length must be a vector of length batch_size, "
"but saw shape: %s" % sequence_length.get_shape())
sequence_length = array_ops.identity( # Just to find it in the graph.
sequence_length,
name="sequence_length")
batch_size = _best_effort_input_batch_size(flat_input)
if initial_state is not None:
state = initial_state
else:
if not dtype:
raise ValueError("If there is no initial_state, you must give a dtype.")
if getattr(cell, "get_initial_state", None) is not None:
state = cell.get_initial_state(
inputs=None, batch_size=batch_size, dtype=dtype)
else:
state = cell.zero_state(batch_size, dtype)
def _assert_has_shape(x, shape):
x_shape = array_ops.shape(x)
packed_shape = array_ops.stack(shape)
return control_flow_ops.Assert(
math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
"Expected shape for Tensor %s is " % x.name, packed_shape,
" but saw shape: ", x_shape
])
if not context.executing_eagerly() and sequence_length is not None:
# Perform some shape validation
with ops.control_dependencies(
[_assert_has_shape(sequence_length, [batch_size])]):
sequence_length = array_ops.identity(
sequence_length, name="CheckSeqLen")
inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
outputs, final_state = _dynamic_rnn_loop(
cell,
inputs,
state,
parallel_iterations=parallel_iterations,
swap_memory=swap_memory,
sequence_length=sequence_length,
dtype=dtype)
# Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
# If we are performing batch-major calculations, transpose output back
# to shape [batch, time, depth]
if not time_major:
# (time, batch, depth) => (batch, time, depth)
outputs = nest.map_structure(_transpose_batch_time, outputs)
outputs = tflite_wrapper.add_output(outputs, name="outputs")
return outputs, final_state

View File

@ -20,12 +20,14 @@ import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.lite.experimental.examples.lstm.tflite_lstm import dynamic_rnn
from tensorflow.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.tools import optimize_for_inference_lib
# Number of steps to train model.
TRAIN_STEPS = 1
@ -67,7 +69,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4")
])
def buildModel(self, lstm_layer, is_dynamic_rnn, is_train):
def buildModel(self, lstm_layer, is_dynamic_rnn):
# Weights and biases for output softmax layer.
out_weights = tf.Variable(
tf.random_normal([self.num_units, self.n_classes]))
@ -77,16 +79,11 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
x = tf.placeholder(
"float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE")
# For dynamic_rnn, train with dynamic_rnn and inference with static_rnn.
# x is shaped [batch_size,time_steps,num_inputs]
if is_dynamic_rnn:
if is_train:
lstm_input = x
outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32")
outputs = tf.unstack(outputs, axis=1)
else:
lstm_input = tf.unstack(x, self.time_steps, 1)
outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
lstm_input = tf.transpose(x, perm=[1, 0, 2])
outputs, _ = dynamic_rnn(lstm_layer, lstm_input, dtype="float32")
outputs = tf.unstack(outputs, axis=0)
else:
lstm_input = tf.unstack(x, self.time_steps, 1)
outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
@ -126,8 +123,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
# Reset the graph.
tf.reset_default_graph()
x, prediction, output_class = self.buildModel(
lstm_layer, is_dynamic_rnn, is_train=False)
x, prediction, output_class = self.buildModel(lstm_layer, is_dynamic_rnn)
new_sess = tf.Session(config=CONFIG)
saver = tf.train.Saver()
@ -159,6 +155,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
tflite = tf.lite.toco_convert(
curr, [tflite_input], [outputs], allow_custom_ops=False)
interpreter = tf.lite.Interpreter(model_content=tflite)
try:
@ -179,7 +176,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
sess = tf.Session(config=CONFIG)
x, prediction, output_class = self.buildModel(
self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True)
self.buildLstmLayer(), is_dynamic_rnn=False)
self.trainModel(x, prediction, output_class, sess)
saver = tf.train.Saver()
@ -192,26 +189,15 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
result = self.tfliteInvoke(frozen_graph, test_inputs, output_class)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
@test_util.enable_control_flow_v2
def testDynamicRnnMultiRnnCell(self):
sess = tf.Session(config=CONFIG)
x, prediction, output_class = self.buildModel(
self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True)
self.buildLstmLayer(), is_dynamic_rnn=True)
self.trainModel(x, prediction, output_class, sess)
# Since we don't yet support OpHints for dynamic, we will load the model
# back in as a static model. This requires the variables to have the same
# names as if they were trained as a static. Thus, we get rid of while/rnn
# names.
variables_to_save = {}
for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
op_name = i.name
if op_name.startswith("while/rnn/"):
op_name = op_name.split("while/rnn/")[1]
if op_name.endswith(":0"):
op_name = op_name.split(":0")[0]
variables_to_save[op_name] = i
saver = tf.train.Saver(variables_to_save)
saver = tf.train.Saver()
x, prediction, output_class, new_sess = self.saveAndRestoreModel(
self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True)

View File

@ -71,6 +71,7 @@ from __future__ import print_function
import collections as _collections
import copy as _copy
import json as _json
import uuid as _uuid
import six as _six
@ -132,6 +133,14 @@ class OpHint(object):
# "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
# attribute to [2, 0, 1, -1].
TFLITE_INPUT_INDICES = "_tflite_input_indices"
# OpHint level.
FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
# Ophint internal mapping, this is for high level Ophint only.
# This basically contains three kinds of mapping:
# 1) How parental ophinted inputs map to the first child ophinted inputs;
# 2) How internal children nodes are connected;
# 3) How parental ophinted outputs map to the last child ophinted outputs.
CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"
# Types of aggregations
# stack: stacks all ophints with matching tags. i.e. for a static rnn.
@ -149,10 +158,16 @@ class OpHint(object):
"""Conceptually tracks indices of arguments of "OpHint functions".
The inputs and arguments of these functions both use an instance
of the class so they can have independent numbering."""
of the class so they can have independent numbering.
"""
def __init__(self, function_name, unique_function_id, node_name_prefix,
attr_name):
def __init__(self,
function_name,
unique_function_id,
node_name_prefix,
attr_name,
level=1,
children_inputs_mappings=None):
"""Initialize ophint argument.
Args:
@ -161,6 +176,8 @@ class OpHint(object):
node_name_prefix: How identities that are created are named.
attr_name: Name of attribute to use to store the index for this hint.
i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
level: Hierarchical level of the Ophint node, a number.
children_inputs_mappings: Inputs/Outputs mapping for children hints.
"""
# The global index is the argument index of the op. This is in contrast
@ -176,6 +193,8 @@ class OpHint(object):
self._tag_to_next_sort_index = {} # The current index for each tag
self._node_name_prefix = node_name_prefix
self._attr_name = attr_name
self._level = level
self._children_inputs_mappings = children_inputs_mappings
def _get_new_global_index(self, index_override):
"""Return the next unused argument index in order or use an override.
@ -251,6 +270,7 @@ class OpHint(object):
uuid = self._unique_function_id
name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
uuid, global_index, sort_index, name)
identity_op = _array_ops.identity(arg, name=name)
# pylint: disable=protected-access
@ -264,6 +284,15 @@ class OpHint(object):
s=_compat.as_bytes(self._unique_function_id)))
identity_op.op._set_attr(
self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
_attr_value_pb2.AttrValue(i=self._level))
if self._children_inputs_mappings:
identity_op.op._set_attr(
OpHint.CHILDREN_INPUTS_MAPPINGS,
_attr_value_pb2.AttrValue(
s=_compat.as_bytes(_json.dumps(
self._children_inputs_mappings))))
if sort_index is not None:
identity_op.op._set_attr(
OpHint.FUNCTION_SORT_INDEX_ATTR,
@ -275,23 +304,74 @@ class OpHint(object):
# pylint: enable=protected-access
return identity_op
def __init__(self, function_name, **kwargs):
def __init__(self,
function_name,
level=1,
children_inputs_mappings=None,
**kwargs):
"""Create a OpHint.
Args:
function_name: Name of the function (the custom op name in tflite)
level: OpHint level.
children_inputs_mappings: Children OpHint inputs/outputs mapping.
children_inputs_mappings should like below:
"parent_first_child_input":
[{"parent_input_index": num, "child_input_index": num}, ...]
"parent_last_child_output":
[{"parent_output_index": num, "child_output_index": num}, ...]
"internal_children_input_output":
[{"child_input_index": num, "child_output_index": num}, ...]
**kwargs: Keyword arguments of any constant attributes for the function.
"""
self._function_name = function_name
self._level = level
if self._level == 1:
assert children_inputs_mappings is None
else:
assert isinstance(children_inputs_mappings, dict)
self._children_inputs_mappings = children_inputs_mappings
if self._children_inputs_mappings is not None:
self._validate_children_inputs_mappings(self._children_inputs_mappings)
self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough?
self._attrs_to_store_later = kwargs
self._stored_attrs = False
self._inputs = OpHint.OpHintArgumentTracker(
self._function_name, self._unique_function_id, "InputHint",
OpHint.FUNCTION_INPUT_INDEX_ATTR)
OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
self._outputs = OpHint.OpHintArgumentTracker(
self._function_name, self._unique_function_id, "OutputHint",
OpHint.FUNCTION_OUTPUT_INDEX_ATTR)
OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
self._children_inputs_mappings)
def _validate_children_inputs_mappings(self, children_inputs_mappings):
"""Validate children inputs mappings is in the right format.
Args:
children_inputs_mappings: the Children ophint inputs/outputs mapping.
"""
assert isinstance(children_inputs_mappings, dict)
assert "parent_first_child_input" in children_inputs_mappings
assert "parent_last_child_output" in children_inputs_mappings
assert "internal_children_input_output" in children_inputs_mappings
# validate parent_first_child_input.
def assert_dictlist_has_keys(dictlist, keys):
for dikt in dictlist:
assert isinstance(dikt, dict)
for key in keys:
assert key in dikt
assert_dictlist_has_keys(
children_inputs_mappings["parent_first_child_input"],
["parent_ophint_input_index", "first_child_ophint_input_index"])
assert_dictlist_has_keys(
children_inputs_mappings["parent_last_child_output"],
["parent_output_index", "child_output_index"])
assert_dictlist_has_keys(
children_inputs_mappings["internal_children_input_output"],
["child_input_index", "child_output_index"])
def _setattr(self, dest_op, name, value):
tensor_value = _ops.convert_to_tensor(value)
@ -382,7 +462,7 @@ class OpHint(object):
class _LiteOperand(object):
"""Abstract operand for a tflite hint function.
"""Abstract operand for a tflite hint function._dynamic_rnn_loop.
This is a base class that handles representing arguments to an OpHint.
It also is able to serialize operands to the stubbed graph_def.
@ -580,15 +660,18 @@ class _LiteFuncCall(object):
This is uses to accumulate found hints in the graphdef into a single
conceptual unit.
Properties:
self.inputs: inputs to the op (hash from index # to argument)
self.outputs: outputs to the op (hash from index # to argument)
self.function_name: the tflite custom op name to use
self.uuid: a unique call id for this particular call (i.e.
Attributes:
inputs: inputs to the op (hash from index # to argument)
outputs: outputs to the op (hash from index # to argument)
function_name: the tflite custom op name to use
uuid: a unique call id for this particular call (i.e.
multiple function calls would have the same function_name but different
uuids.
self.params: A param name to key value for op constant data. I.e. for
params: A param name to key value for op constant data. I.e. for
axis on a reduction, strides on a convolution, etc.
level: Level of the OpHint.
children_inputs_mappings: If the Ophint has children, children inputs
mappings indicate how their inputs & outputs are mapped.
"""
def __init__(self):
@ -597,6 +680,8 @@ class _LiteFuncCall(object):
self.function_name = None
self.uuid = None
self.params = {}
self.level = -1
self.children_inputs_mappings = {}
def flattened_inputs_and_outputs(self):
"""Return a list of inputs and outputs in a flattened format.
@ -622,22 +707,25 @@ class _LiteFuncCall(object):
inputs_str = "\tInputs\n" + format_args(self.inputs)
outputs_str = "\tOutputs\n" + format_args(self.outputs)
return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s"
% (self.function_name, self.uuid, inputs_str, outputs_str))
return (
"tflite function %s call %s level %d "
"\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
(self.function_name, self.uuid, self.level, inputs_str, outputs_str))
def _find_all_hints_in_graph_def(graphdef):
"""Look at the current default graph and return a list of LiteFuncCall objs.
def _find_all_hints_in_nodes(nodes):
"""Look at the all the input nodes and return a list of LiteFuncCall objs.
Args:
graphdef: A TensorFlow graph_def to look for LiteFuncCalls.
nodes: A TensorFlow graph_def to look for LiteFuncCalls.
Returns:
a list of `LifeFuncCall` objects in the form
"""
func_calls = _collections.defaultdict(_LiteFuncCall)
for node in graphdef.node:
for node in nodes:
attr = node.attr
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
@ -649,6 +737,7 @@ def _find_all_hints_in_graph_def(graphdef):
call_def = func_calls[uuid]
call_def.uuid = uuid
call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
# Get sorting and aggregation information
sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
@ -658,6 +747,10 @@ def _find_all_hints_in_graph_def(graphdef):
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
call_def.children_inputs_mappings = _json.loads(
_compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))
# Add the input or output
def put_operand(stuff, index, sort, operand, aggregation):
"""Add a given index into the function structure."""
@ -683,6 +776,98 @@ def _find_all_hints_in_graph_def(graphdef):
return func_calls
def _extract_topology_sequence_mapping(nodes):
return dict(
(_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))
def _find_children_hints_in_while_loop(function_def, nodes_mapping):
"""Find children hints and all nodes inside the while loop.
Args:
function_def: Function def of the while loop.
nodes_mapping: While loop input_arg : real node name.
Returns:
Ordered children hints and all re-mapped nodes inside the while loop.
"""
new_nodes = []
# Make nodes inside function def inputs point to the real nodes.
for node in function_def.node_def:
for i in range(len(node.input)):
if node.input[i] in nodes_mapping:
node.input[i] = nodes_mapping[node.input[i]]
new_nodes.append(_copy.deepcopy(node))
name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
children_hints = _find_all_hints_in_nodes(new_nodes)
children_hints_q = []
# Ordered by the outputs.
for hint in _six.itervalues(children_hints):
_, output_names = hint.flattened_inputs_and_outputs()
seq = name_to_seq_num[output_names[0]]
for output_name in output_names:
seq = min(seq, name_to_seq_num[output_name])
children_hints_q.append((seq, hint))
children_hints_q.sort(key=lambda tup: tup[0])
ordered_children_hints = [x[1] for x in children_hints_q]
return ordered_children_hints, new_nodes
def _find_children_hints(call, graph_def):
"""Find all children hints.
For a given OpHint, we find all children hints inside it, we also copy all the
nodes inside function defs (if applicable) to the original graph_def, they are
returned in a list as well.
Args:
call: Parent OpHint that contains children ophints.
graph_def: Original graph def.
Returns:
Ordered children hints inside the parent ophint; new graph def that contains
nodes inside function defs (if applicable); nodes inside function defs.
"""
name_to_input_name, _, _ = _extract_graph_summary(graph_def)
input_names, output_names = call.flattened_inputs_and_outputs()
reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
reachable_by_output = _bfs_for_reachable_nodes(output_names,
name_to_input_name)
output_nodes_set = set(output_names)
children_hints = []
out = _graph_pb2.GraphDef()
out.library.CopyFrom(graph_def.library)
out.versions.CopyFrom(graph_def.versions)
function_def_nodes = set()
for node in graph_def.node:
out.node.extend([_copy.deepcopy(node)])
n = _tensor_name_base(node.name)
if n in reachable_by_output:
if n not in reachable_by_input and n not in output_nodes_set:
# special handle for while loop function def.
if node.op == "While":
body_name = node.attr["body"].func.name
inputs_outside_loop = node.input
for function_def in graph_def.library.function:
if function_def.signature.name == body_name:
function_inputs = function_def.signature.input_arg
assert len(inputs_outside_loop) == len(function_inputs)
nodes_mapping = {}
for i in range(len(function_inputs)):
nodes_mapping[function_inputs[i].name] = inputs_outside_loop[i]
# TODO(b/123050804): Consider use grappler.
(children_hints_in_loop,
new_nodes) = _find_children_hints_in_while_loop(
function_def, nodes_mapping)
function_def_nodes.update([x.name for x in new_nodes])
children_hints.extend(children_hints_in_loop)
out.node.extend(new_nodes)
return children_hints, out, function_def_nodes
def _tensor_name_base(full_tensor_name):
"""Removes the device assignment code from a tensor.
@ -735,12 +920,20 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
# TODO(aselle): This should be converted to grappler in the future.
def _convert_single_op_hint_to_stub(call, graph_def):
def _convert_single_op_hint_to_stub(call,
graph_def,
function_def_nodes=None,
is_last_run=True):
"""Given a graph_def, converts `call` into a stub and returns a new graph_def.
Args:
call: A single function call to be converted.
graph_def: A graph_def to use as input (that hass call obviously).
function_def_nodes: Nodes inside the function def those are not connected to
the graph.
is_last_run: Whether it is the last run for a given pass (for OpHint has
children).
Returns:
A new transformed graph-def that has call as a stub (single op).
@ -748,6 +941,8 @@ def _convert_single_op_hint_to_stub(call, graph_def):
the tensorflow runtime, so all future manipulations are done in graph_def
level.
"""
if function_def_nodes is None:
function_def_nodes = set()
name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
graph_def)
input_names, output_names = call.flattened_inputs_and_outputs()
@ -755,7 +950,6 @@ def _convert_single_op_hint_to_stub(call, graph_def):
reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
reachable_by_output = _bfs_for_reachable_nodes(output_names,
name_to_input_name)
input_nodes_set = set(input_names)
output_nodes_set = set(output_names)
nodes_after_fuse = []
nodes_deleted_by_fuse = set()
@ -766,19 +960,16 @@ def _convert_single_op_hint_to_stub(call, graph_def):
n = _tensor_name_base(node.name)
if n in reachable_by_output:
if n not in reachable_by_input and n not in output_nodes_set:
# n is an internal node. Check to make sure it is really internal.
# TODO(aselle): this could be done more efficiently by flooding
# the graph first.
_check_subgraph_closed(n, reachable_by_input, input_nodes_set,
name_to_input_name)
nodes_deleted_by_fuse.add(n)
elif n not in reachable_by_input:
elif n not in reachable_by_input and n not in function_def_nodes:
# n is a node that after all the fusings, so keep it.
nodes_after_fuse.append(n)
else:
# n is a node that is randomly in the graph but not connected to
# the chain of dependencies.
pass
# In the last run, n is a node that is randomly in the graph but not
# connected to the chain of dependencies, we will delete n, otherwise
# we keep them.
if not is_last_run:
nodes_after_fuse.append(n)
# Make a new graphdef with all the pre-input and input nodes
out = _graph_pb2.GraphDef()
@ -800,7 +991,8 @@ def _convert_single_op_hint_to_stub(call, graph_def):
# non-fused things.
for input_index in sorted_input_indices:
inputs = call.inputs[input_index]
new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
input_name = inputs.aggregate_and_return_name_for_input(out)
new_node.input.append(input_name)
new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
# Ceate the function
@ -936,6 +1128,18 @@ def _remove_redundant_stack_unstack(graph_def):
return curr
def _get_correct_mapping(original_index, nodes):
# Special handle for the index is -1 case.
# If it is -1, return the last index.
if original_index == -1:
node_indices = nodes.keys()
node_indices.sort()
return node_indices[-1]
else:
return original_index
return original_index
@_tf_export("lite.convert_op_hints_to_stubs")
def _convert_op_hints_to_stubs_helper(
graph_def, write_callback=lambda sess, graph_def: None):
@ -948,14 +1152,67 @@ def _convert_op_hints_to_stubs_helper(
Returns:
A new stubbed graph_def.
"""
hints = _find_all_hints_in_nodes(graph_def.node)
hints_q = []
for hint in _six.itervalues(hints):
hints_q.append((hint.level, hint.uuid))
hints_q.sort(key=lambda tup: tup[0])
for i in range(len(hints_q) - 1, -1, -1):
level, hint_uuid = hints_q[i]
hints = _find_all_hints_in_graph_def(graph_def)
curr_graph_def = graph_def
del graph_def # prevent using graph_def again (common source of error)
for hint in _six.itervalues(hints):
curr_graph_def = _convert_single_op_hint_to_stub(
hint, curr_graph_def)
write_callback(curr_graph_def, "initial")
for i in range(len(hints_q) - 1, -1, -1):
level, hint_uuid = hints_q[i]
if level >= 2:
children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
hints[hint_uuid], curr_graph_def)
# pylint: disable=superfluous-parens
assert (len(children_hints) > 0) # pylint: disable=g-explicit-length-test
# pylint: enable=superfluous-parens
# Re-wire the children hints inputs/outputs, so latter child's inputs
# connect to previous child node's outputs.
children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
for j in range(len(children_hints)):
child_hint = children_hints[j]
if j == 0:
for mapping in children_inputs_mappings["parent_first_child_input"]:
parent_input_index = _get_correct_mapping(
mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
child_input_index = _get_correct_mapping(
mapping["first_child_ophint_input_index"], child_hint.inputs)
child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
parent_input_index]
else:
for mapping in children_inputs_mappings[
"internal_children_input_output"]:
input_index = _get_correct_mapping(mapping["child_input_index"],
child_hint.inputs)
output_index = _get_correct_mapping(mapping["child_output_index"],
children_hints[j - 1].outputs)
child_hint.inputs[input_index] = children_hints[
j - 1].outputs[output_index]
if j == len(children_hints) - 1:
for mapping in children_inputs_mappings["parent_last_child_output"]:
parent_output_index = _get_correct_mapping(
mapping["parent_output_index"], hints[hint_uuid].outputs)
child_output_index = _get_correct_mapping(
mapping["child_output_index"], child_hint.outputs)
child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
parent_output_index]
for j in range(len(children_hints)):
child_hint = children_hints[j]
curr_graph_def = _convert_single_op_hint_to_stub(
child_hint, curr_graph_def, function_def_nodes,
j == len(children_hints) - 1)
else:
curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
curr_graph_def)
write_callback(curr_graph_def, "initial")
# The stubbing process can create stacks/unstacks in the case of LSTMs
# remove them.
curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
@ -982,9 +1239,9 @@ def find_all_hinted_output_nodes(session=None, graph_def=None):
raise ValueError("Provide only one of session and graph_def.")
hinted_outputs_nodes = []
if session is not None:
hints = _find_all_hints_in_graph_def(session.graph_def)
hints = _find_all_hints_in_nodes(session.graph_def.node)
elif graph_def is not None:
hints = _find_all_hints_in_graph_def(graph_def)
hints = _find_all_hints_in_nodes(graph_def.node)
for hint in _six.itervalues(hints):
_, ouput_nodes = hint.flattened_inputs_and_outputs()
hinted_outputs_nodes.extend(ouput_nodes)

View File

@ -143,13 +143,14 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
# Breadth first search to find all the nodes that we should keep.
next_to_visit = target_nodes[:]
while next_to_visit:
n = next_to_visit[0]
node = next_to_visit[0]
del next_to_visit[0]
if n in nodes_to_keep:
if node in nodes_to_keep:
# Already visited this node.
continue
nodes_to_keep.add(n)
next_to_visit += name_to_input_name[n]
nodes_to_keep.add(node)
if node in name_to_input_name:
next_to_visit += name_to_input_name[node]
return nodes_to_keep

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add"

View File

@ -14,6 +14,10 @@ tf_class {
name: "AGGREGATE_STACK"
mtype: "<type \'str\'>"
}
member {
name: "CHILDREN_INPUTS_MAPPINGS"
mtype: "<type \'str\'>"
}
member {
name: "FUNCTION_AGGREGATE_ATTR"
mtype: "<type \'str\'>"
@ -22,6 +26,10 @@ tf_class {
name: "FUNCTION_INPUT_INDEX_ATTR"
mtype: "<type \'str\'>"
}
member {
name: "FUNCTION_LEVEL_ATTR"
mtype: "<type \'str\'>"
}
member {
name: "FUNCTION_NAME_ATTR"
mtype: "<type \'str\'>"
@ -48,7 +56,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None"
argspec: "args=[\'self\', \'function_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add_input"

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add"

View File

@ -14,6 +14,10 @@ tf_class {
name: "AGGREGATE_STACK"
mtype: "<type \'str\'>"
}
member {
name: "CHILDREN_INPUTS_MAPPINGS"
mtype: "<type \'str\'>"
}
member {
name: "FUNCTION_AGGREGATE_ATTR"
mtype: "<type \'str\'>"
@ -22,6 +26,10 @@ tf_class {
name: "FUNCTION_INPUT_INDEX_ATTR"
mtype: "<type \'str\'>"
}
member {
name: "FUNCTION_LEVEL_ATTR"
mtype: "<type \'str\'>"
}
member {
name: "FUNCTION_NAME_ATTR"
mtype: "<type \'str\'>"
@ -48,7 +56,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None"
argspec: "args=[\'self\', \'function_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add_input"