OpHints supports dynamic_rnn, currently only add support for time_major=True case.
PiperOrigin-RevId: 231709454
This commit is contained in:
parent
7d8bfd88cd
commit
82ede0271e
@ -44,7 +44,7 @@ py_test(
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/examples/tutorials/mnist:input_data",
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python/tools:optimize_for_inference",
|
||||
"//third_party/py/numpy",
|
||||
|
@ -22,17 +22,27 @@ from __future__ import print_function
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import activations
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.layers import base as base_layer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops.rnn import _best_effort_input_batch_size
|
||||
from tensorflow.python.ops.rnn import _dynamic_rnn_loop
|
||||
from tensorflow.python.ops.rnn import _should_cache
|
||||
from tensorflow.python.ops.rnn import _transpose_batch_time
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
||||
@ -394,3 +404,240 @@ class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
|
||||
}
|
||||
base_config = super(TFLiteLSTMCell, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
def dynamic_rnn(cell,
|
||||
inputs,
|
||||
sequence_length=None,
|
||||
initial_state=None,
|
||||
dtype=None,
|
||||
parallel_iterations=None,
|
||||
swap_memory=False,
|
||||
time_major=True,
|
||||
scope=None):
|
||||
"""Creates a recurrent neural network specified by RNNCell `cell`.
|
||||
|
||||
Performs fully dynamic unrolling of `inputs`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
# create a BasicRNNCell
|
||||
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
|
||||
|
||||
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
|
||||
|
||||
# defining initial state
|
||||
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
|
||||
|
||||
# 'state' is a tensor of shape [batch_size, cell_state_size]
|
||||
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
|
||||
initial_state=initial_state,
|
||||
dtype=tf.float32)
|
||||
```
|
||||
|
||||
```python
|
||||
# create 2 LSTMCells
|
||||
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
|
||||
|
||||
# create a RNN cell composed sequentially of a number of RNNCells
|
||||
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
|
||||
|
||||
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
|
||||
# 'state' is a N-tuple where N is the number of LSTMCells containing a
|
||||
# tf.contrib.rnn.LSTMStateTuple for each cell
|
||||
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
|
||||
inputs=data,
|
||||
dtype=tf.float32)
|
||||
```
|
||||
|
||||
|
||||
Args:
|
||||
cell: An instance of RNNCell.
|
||||
inputs: The RNN inputs.
|
||||
If `time_major == False` (default), this must be a `Tensor` of shape:
|
||||
`[batch_size, max_time, ...]`, or a nested tuple of such elements.
|
||||
If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
|
||||
batch_size, ...]`, or a nested tuple of such elements. This may also be
|
||||
a (possibly nested) tuple of Tensors satisfying this property. The
|
||||
first two dimensions must match across all the inputs, but otherwise the
|
||||
ranks and other shape components may differ. In this case, input to
|
||||
`cell` at each time-step will replicate the structure of these tuples,
|
||||
except for the time dimension (from which the time is taken). The input
|
||||
to `cell` at each time step will be a `Tensor` or (possibly nested)
|
||||
tuple of Tensors each with dimensions `[batch_size, ...]`.
|
||||
sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
|
||||
to copy-through state and zero-out outputs when past a batch element's
|
||||
sequence length. So it's more for performance than correctness.
|
||||
initial_state: (optional) An initial state for the RNN. If `cell.state_size`
|
||||
is an integer, this must be a `Tensor` of appropriate type and shape
|
||||
`[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
|
||||
should be a tuple of tensors having shapes `[batch_size, s] for s in
|
||||
cell.state_size`.
|
||||
dtype: (optional) The data type for the initial state and expected output.
|
||||
Required if initial_state is not provided or RNN state has a heterogeneous
|
||||
dtype.
|
||||
parallel_iterations: (Default: 32). The number of iterations to run in
|
||||
parallel. Those operations which do not have any temporal dependency and
|
||||
can be run in parallel, will be. This parameter trades off time for
|
||||
space. Values >> 1 use more memory but take less time, while smaller
|
||||
values use less memory but computations take longer.
|
||||
swap_memory: Transparently swap the tensors produced in forward inference
|
||||
but needed for back prop from GPU to CPU. This allows training RNNs which
|
||||
would typically not fit on a single GPU, with very minimal (or no)
|
||||
performance penalty.
|
||||
time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
|
||||
these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
|
||||
these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
|
||||
`time_major = True` is a bit more efficient because it avoids transposes
|
||||
at the beginning and end of the RNN calculation. However, most TensorFlow
|
||||
data is batch-major, so by default this function accepts input and emits
|
||||
output in batch-major form.
|
||||
scope: VariableScope for the created subgraph; defaults to "rnn".
|
||||
|
||||
Returns:
|
||||
A pair (outputs, state) where:
|
||||
|
||||
outputs: The RNN output `Tensor`.
|
||||
|
||||
If time_major == False (default), this will be a `Tensor` shaped:
|
||||
`[batch_size, max_time, cell.output_size]`.
|
||||
|
||||
If time_major == True, this will be a `Tensor` shaped:
|
||||
`[max_time, batch_size, cell.output_size]`.
|
||||
|
||||
Note, if `cell.output_size` is a (possibly nested) tuple of integers
|
||||
or `TensorShape` objects, then `outputs` will be a tuple having the
|
||||
same structure as `cell.output_size`, containing Tensors having shapes
|
||||
corresponding to the shape data in `cell.output_size`.
|
||||
|
||||
state: The final state. If `cell.state_size` is an int, this
|
||||
will be shaped `[batch_size, cell.state_size]`. If it is a
|
||||
`TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
|
||||
If it is a (possibly nested) tuple of ints or `TensorShape`, this will
|
||||
be a tuple having the corresponding shapes. If cells are `LSTMCells`
|
||||
`state` will be a tuple containing a `LSTMStateTuple` for each cell.
|
||||
|
||||
Raises:
|
||||
TypeError: If `cell` is not an instance of RNNCell.
|
||||
ValueError: If inputs is None or an empty list.
|
||||
RuntimeError: If not using control flow v2.
|
||||
"""
|
||||
|
||||
# Currently only support time_major == True case.
|
||||
assert time_major
|
||||
|
||||
# TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
|
||||
# TfLiteRNNCells.
|
||||
rnn_cell_impl.assert_like_rnncell("cell", cell)
|
||||
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")
|
||||
|
||||
parent_first_child_input = [{
|
||||
"parent_ophint_input_index": 0,
|
||||
"first_child_ophint_input_index": 0
|
||||
}]
|
||||
parent_last_child_output = [{
|
||||
"parent_output_index": 0,
|
||||
# For LstmCell, the index is 2.
|
||||
# For RnnCell, the index is 1.
|
||||
# So we use -1 meaning it's the last one.
|
||||
"child_output_index": -1
|
||||
}]
|
||||
internal_children_input_output = [{
|
||||
"child_input_index": 0,
|
||||
# For LstmCell, the index is 2.
|
||||
# For RnnCell, the index is 1.
|
||||
# So we use -1 meaning it's the last one.
|
||||
"child_output_index": -1
|
||||
}]
|
||||
inputs_outputs_mappings = {
|
||||
"parent_first_child_input": parent_first_child_input,
|
||||
"parent_last_child_output": parent_last_child_output,
|
||||
"internal_children_input_output": internal_children_input_output
|
||||
}
|
||||
tflite_wrapper = lite.OpHint(
|
||||
"TfLiteDynamicRnn",
|
||||
level=2,
|
||||
children_inputs_mappings=inputs_outputs_mappings)
|
||||
with vs.variable_scope(scope or "rnn") as varscope:
|
||||
# Create a new scope in which the caching device is either
|
||||
# determined by the parent scope, or is set to place the cached
|
||||
# Variable using the same placement as for the rest of the RNN.
|
||||
if _should_cache():
|
||||
if varscope.caching_device is None:
|
||||
varscope.set_caching_device(lambda op: op.device)
|
||||
|
||||
inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)
|
||||
|
||||
# By default, time_major==False and inputs are batch-major: shaped
|
||||
# [batch, time, depth]
|
||||
# For internal calculations, we transpose to [time, batch, depth]
|
||||
flat_input = nest.flatten(inputs)
|
||||
|
||||
if not time_major:
|
||||
# (batch, time, depth) => (time, batch, depth)
|
||||
flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
|
||||
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
|
||||
|
||||
parallel_iterations = parallel_iterations or 32
|
||||
if sequence_length is not None:
|
||||
sequence_length = math_ops.to_int32(sequence_length)
|
||||
if sequence_length.get_shape().rank not in (None, 1):
|
||||
raise ValueError(
|
||||
"sequence_length must be a vector of length batch_size, "
|
||||
"but saw shape: %s" % sequence_length.get_shape())
|
||||
sequence_length = array_ops.identity( # Just to find it in the graph.
|
||||
sequence_length,
|
||||
name="sequence_length")
|
||||
|
||||
batch_size = _best_effort_input_batch_size(flat_input)
|
||||
|
||||
if initial_state is not None:
|
||||
state = initial_state
|
||||
else:
|
||||
if not dtype:
|
||||
raise ValueError("If there is no initial_state, you must give a dtype.")
|
||||
if getattr(cell, "get_initial_state", None) is not None:
|
||||
state = cell.get_initial_state(
|
||||
inputs=None, batch_size=batch_size, dtype=dtype)
|
||||
else:
|
||||
state = cell.zero_state(batch_size, dtype)
|
||||
|
||||
def _assert_has_shape(x, shape):
|
||||
x_shape = array_ops.shape(x)
|
||||
packed_shape = array_ops.stack(shape)
|
||||
return control_flow_ops.Assert(
|
||||
math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
|
||||
"Expected shape for Tensor %s is " % x.name, packed_shape,
|
||||
" but saw shape: ", x_shape
|
||||
])
|
||||
|
||||
if not context.executing_eagerly() and sequence_length is not None:
|
||||
# Perform some shape validation
|
||||
with ops.control_dependencies(
|
||||
[_assert_has_shape(sequence_length, [batch_size])]):
|
||||
sequence_length = array_ops.identity(
|
||||
sequence_length, name="CheckSeqLen")
|
||||
|
||||
inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
|
||||
|
||||
outputs, final_state = _dynamic_rnn_loop(
|
||||
cell,
|
||||
inputs,
|
||||
state,
|
||||
parallel_iterations=parallel_iterations,
|
||||
swap_memory=swap_memory,
|
||||
sequence_length=sequence_length,
|
||||
dtype=dtype)
|
||||
|
||||
# Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
|
||||
# If we are performing batch-major calculations, transpose output back
|
||||
# to shape [batch, time, depth]
|
||||
if not time_major:
|
||||
# (time, batch, depth) => (batch, time, depth)
|
||||
outputs = nest.map_structure(_transpose_batch_time, outputs)
|
||||
outputs = tflite_wrapper.add_output(outputs, name="outputs")
|
||||
|
||||
return outputs, final_state
|
||||
|
@ -20,12 +20,14 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
from tensorflow.lite.experimental.examples.lstm.tflite_lstm import dynamic_rnn
|
||||
from tensorflow.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tools import optimize_for_inference_lib
|
||||
|
||||
|
||||
# Number of steps to train model.
|
||||
TRAIN_STEPS = 1
|
||||
|
||||
@ -67,7 +69,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4")
|
||||
])
|
||||
|
||||
def buildModel(self, lstm_layer, is_dynamic_rnn, is_train):
|
||||
def buildModel(self, lstm_layer, is_dynamic_rnn):
|
||||
# Weights and biases for output softmax layer.
|
||||
out_weights = tf.Variable(
|
||||
tf.random_normal([self.num_units, self.n_classes]))
|
||||
@ -77,16 +79,11 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
x = tf.placeholder(
|
||||
"float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE")
|
||||
|
||||
# For dynamic_rnn, train with dynamic_rnn and inference with static_rnn.
|
||||
# x is shaped [batch_size,time_steps,num_inputs]
|
||||
if is_dynamic_rnn:
|
||||
if is_train:
|
||||
lstm_input = x
|
||||
outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32")
|
||||
outputs = tf.unstack(outputs, axis=1)
|
||||
else:
|
||||
lstm_input = tf.unstack(x, self.time_steps, 1)
|
||||
outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
|
||||
lstm_input = tf.transpose(x, perm=[1, 0, 2])
|
||||
outputs, _ = dynamic_rnn(lstm_layer, lstm_input, dtype="float32")
|
||||
outputs = tf.unstack(outputs, axis=0)
|
||||
else:
|
||||
lstm_input = tf.unstack(x, self.time_steps, 1)
|
||||
outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
|
||||
@ -126,8 +123,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
|
||||
# Reset the graph.
|
||||
tf.reset_default_graph()
|
||||
x, prediction, output_class = self.buildModel(
|
||||
lstm_layer, is_dynamic_rnn, is_train=False)
|
||||
x, prediction, output_class = self.buildModel(lstm_layer, is_dynamic_rnn)
|
||||
|
||||
new_sess = tf.Session(config=CONFIG)
|
||||
saver = tf.train.Saver()
|
||||
@ -159,6 +155,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
|
||||
tflite = tf.lite.toco_convert(
|
||||
curr, [tflite_input], [outputs], allow_custom_ops=False)
|
||||
|
||||
interpreter = tf.lite.Interpreter(model_content=tflite)
|
||||
|
||||
try:
|
||||
@ -179,7 +176,7 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
sess = tf.Session(config=CONFIG)
|
||||
|
||||
x, prediction, output_class = self.buildModel(
|
||||
self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True)
|
||||
self.buildLstmLayer(), is_dynamic_rnn=False)
|
||||
self.trainModel(x, prediction, output_class, sess)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
@ -192,26 +189,15 @@ class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
|
||||
result = self.tfliteInvoke(frozen_graph, test_inputs, output_class)
|
||||
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def testDynamicRnnMultiRnnCell(self):
|
||||
sess = tf.Session(config=CONFIG)
|
||||
|
||||
x, prediction, output_class = self.buildModel(
|
||||
self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True)
|
||||
self.buildLstmLayer(), is_dynamic_rnn=True)
|
||||
self.trainModel(x, prediction, output_class, sess)
|
||||
|
||||
# Since we don't yet support OpHints for dynamic, we will load the model
|
||||
# back in as a static model. This requires the variables to have the same
|
||||
# names as if they were trained as a static. Thus, we get rid of while/rnn
|
||||
# names.
|
||||
variables_to_save = {}
|
||||
for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
|
||||
op_name = i.name
|
||||
if op_name.startswith("while/rnn/"):
|
||||
op_name = op_name.split("while/rnn/")[1]
|
||||
if op_name.endswith(":0"):
|
||||
op_name = op_name.split(":0")[0]
|
||||
variables_to_save[op_name] = i
|
||||
saver = tf.train.Saver(variables_to_save)
|
||||
saver = tf.train.Saver()
|
||||
|
||||
x, prediction, output_class, new_sess = self.saveAndRestoreModel(
|
||||
self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True)
|
||||
|
@ -71,6 +71,7 @@ from __future__ import print_function
|
||||
|
||||
import collections as _collections
|
||||
import copy as _copy
|
||||
import json as _json
|
||||
import uuid as _uuid
|
||||
import six as _six
|
||||
|
||||
@ -132,6 +133,14 @@ class OpHint(object):
|
||||
# "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
|
||||
# attribute to [2, 0, 1, -1].
|
||||
TFLITE_INPUT_INDICES = "_tflite_input_indices"
|
||||
# OpHint level.
|
||||
FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
|
||||
# Ophint internal mapping, this is for high level Ophint only.
|
||||
# This basically contains three kinds of mapping:
|
||||
# 1) How parental ophinted inputs map to the first child ophinted inputs;
|
||||
# 2) How internal children nodes are connected;
|
||||
# 3) How parental ophinted outputs map to the last child ophinted outputs.
|
||||
CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"
|
||||
|
||||
# Types of aggregations
|
||||
# stack: stacks all ophints with matching tags. i.e. for a static rnn.
|
||||
@ -149,10 +158,16 @@ class OpHint(object):
|
||||
"""Conceptually tracks indices of arguments of "OpHint functions".
|
||||
|
||||
The inputs and arguments of these functions both use an instance
|
||||
of the class so they can have independent numbering."""
|
||||
of the class so they can have independent numbering.
|
||||
"""
|
||||
|
||||
def __init__(self, function_name, unique_function_id, node_name_prefix,
|
||||
attr_name):
|
||||
def __init__(self,
|
||||
function_name,
|
||||
unique_function_id,
|
||||
node_name_prefix,
|
||||
attr_name,
|
||||
level=1,
|
||||
children_inputs_mappings=None):
|
||||
"""Initialize ophint argument.
|
||||
|
||||
Args:
|
||||
@ -161,6 +176,8 @@ class OpHint(object):
|
||||
node_name_prefix: How identities that are created are named.
|
||||
attr_name: Name of attribute to use to store the index for this hint.
|
||||
i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
|
||||
level: Hierarchical level of the Ophint node, a number.
|
||||
children_inputs_mappings: Inputs/Outputs mapping for children hints.
|
||||
"""
|
||||
|
||||
# The global index is the argument index of the op. This is in contrast
|
||||
@ -176,6 +193,8 @@ class OpHint(object):
|
||||
self._tag_to_next_sort_index = {} # The current index for each tag
|
||||
self._node_name_prefix = node_name_prefix
|
||||
self._attr_name = attr_name
|
||||
self._level = level
|
||||
self._children_inputs_mappings = children_inputs_mappings
|
||||
|
||||
def _get_new_global_index(self, index_override):
|
||||
"""Return the next unused argument index in order or use an override.
|
||||
@ -251,6 +270,7 @@ class OpHint(object):
|
||||
uuid = self._unique_function_id
|
||||
name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
|
||||
uuid, global_index, sort_index, name)
|
||||
|
||||
identity_op = _array_ops.identity(arg, name=name)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@ -264,6 +284,15 @@ class OpHint(object):
|
||||
s=_compat.as_bytes(self._unique_function_id)))
|
||||
identity_op.op._set_attr(
|
||||
self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
|
||||
identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
|
||||
_attr_value_pb2.AttrValue(i=self._level))
|
||||
if self._children_inputs_mappings:
|
||||
identity_op.op._set_attr(
|
||||
OpHint.CHILDREN_INPUTS_MAPPINGS,
|
||||
_attr_value_pb2.AttrValue(
|
||||
s=_compat.as_bytes(_json.dumps(
|
||||
self._children_inputs_mappings))))
|
||||
|
||||
if sort_index is not None:
|
||||
identity_op.op._set_attr(
|
||||
OpHint.FUNCTION_SORT_INDEX_ATTR,
|
||||
@ -275,23 +304,74 @@ class OpHint(object):
|
||||
# pylint: enable=protected-access
|
||||
return identity_op
|
||||
|
||||
def __init__(self, function_name, **kwargs):
|
||||
def __init__(self,
|
||||
function_name,
|
||||
level=1,
|
||||
children_inputs_mappings=None,
|
||||
**kwargs):
|
||||
"""Create a OpHint.
|
||||
|
||||
Args:
|
||||
function_name: Name of the function (the custom op name in tflite)
|
||||
level: OpHint level.
|
||||
children_inputs_mappings: Children OpHint inputs/outputs mapping.
|
||||
children_inputs_mappings should like below:
|
||||
"parent_first_child_input":
|
||||
[{"parent_input_index": num, "child_input_index": num}, ...]
|
||||
"parent_last_child_output":
|
||||
[{"parent_output_index": num, "child_output_index": num}, ...]
|
||||
"internal_children_input_output":
|
||||
[{"child_input_index": num, "child_output_index": num}, ...]
|
||||
**kwargs: Keyword arguments of any constant attributes for the function.
|
||||
"""
|
||||
self._function_name = function_name
|
||||
self._level = level
|
||||
if self._level == 1:
|
||||
assert children_inputs_mappings is None
|
||||
else:
|
||||
assert isinstance(children_inputs_mappings, dict)
|
||||
self._children_inputs_mappings = children_inputs_mappings
|
||||
if self._children_inputs_mappings is not None:
|
||||
self._validate_children_inputs_mappings(self._children_inputs_mappings)
|
||||
self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough?
|
||||
self._attrs_to_store_later = kwargs
|
||||
self._stored_attrs = False
|
||||
self._inputs = OpHint.OpHintArgumentTracker(
|
||||
self._function_name, self._unique_function_id, "InputHint",
|
||||
OpHint.FUNCTION_INPUT_INDEX_ATTR)
|
||||
OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
|
||||
self._outputs = OpHint.OpHintArgumentTracker(
|
||||
self._function_name, self._unique_function_id, "OutputHint",
|
||||
OpHint.FUNCTION_OUTPUT_INDEX_ATTR)
|
||||
OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
|
||||
self._children_inputs_mappings)
|
||||
|
||||
def _validate_children_inputs_mappings(self, children_inputs_mappings):
|
||||
"""Validate children inputs mappings is in the right format.
|
||||
|
||||
Args:
|
||||
children_inputs_mappings: the Children ophint inputs/outputs mapping.
|
||||
"""
|
||||
assert isinstance(children_inputs_mappings, dict)
|
||||
assert "parent_first_child_input" in children_inputs_mappings
|
||||
assert "parent_last_child_output" in children_inputs_mappings
|
||||
assert "internal_children_input_output" in children_inputs_mappings
|
||||
|
||||
# validate parent_first_child_input.
|
||||
|
||||
def assert_dictlist_has_keys(dictlist, keys):
|
||||
for dikt in dictlist:
|
||||
assert isinstance(dikt, dict)
|
||||
for key in keys:
|
||||
assert key in dikt
|
||||
|
||||
assert_dictlist_has_keys(
|
||||
children_inputs_mappings["parent_first_child_input"],
|
||||
["parent_ophint_input_index", "first_child_ophint_input_index"])
|
||||
assert_dictlist_has_keys(
|
||||
children_inputs_mappings["parent_last_child_output"],
|
||||
["parent_output_index", "child_output_index"])
|
||||
assert_dictlist_has_keys(
|
||||
children_inputs_mappings["internal_children_input_output"],
|
||||
["child_input_index", "child_output_index"])
|
||||
|
||||
def _setattr(self, dest_op, name, value):
|
||||
tensor_value = _ops.convert_to_tensor(value)
|
||||
@ -382,7 +462,7 @@ class OpHint(object):
|
||||
|
||||
|
||||
class _LiteOperand(object):
|
||||
"""Abstract operand for a tflite hint function.
|
||||
"""Abstract operand for a tflite hint function._dynamic_rnn_loop.
|
||||
|
||||
This is a base class that handles representing arguments to an OpHint.
|
||||
It also is able to serialize operands to the stubbed graph_def.
|
||||
@ -580,15 +660,18 @@ class _LiteFuncCall(object):
|
||||
This is uses to accumulate found hints in the graphdef into a single
|
||||
conceptual unit.
|
||||
|
||||
Properties:
|
||||
self.inputs: inputs to the op (hash from index # to argument)
|
||||
self.outputs: outputs to the op (hash from index # to argument)
|
||||
self.function_name: the tflite custom op name to use
|
||||
self.uuid: a unique call id for this particular call (i.e.
|
||||
Attributes:
|
||||
inputs: inputs to the op (hash from index # to argument)
|
||||
outputs: outputs to the op (hash from index # to argument)
|
||||
function_name: the tflite custom op name to use
|
||||
uuid: a unique call id for this particular call (i.e.
|
||||
multiple function calls would have the same function_name but different
|
||||
uuids.
|
||||
self.params: A param name to key value for op constant data. I.e. for
|
||||
params: A param name to key value for op constant data. I.e. for
|
||||
axis on a reduction, strides on a convolution, etc.
|
||||
level: Level of the OpHint.
|
||||
children_inputs_mappings: If the Ophint has children, children inputs
|
||||
mappings indicate how their inputs & outputs are mapped.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@ -597,6 +680,8 @@ class _LiteFuncCall(object):
|
||||
self.function_name = None
|
||||
self.uuid = None
|
||||
self.params = {}
|
||||
self.level = -1
|
||||
self.children_inputs_mappings = {}
|
||||
|
||||
def flattened_inputs_and_outputs(self):
|
||||
"""Return a list of inputs and outputs in a flattened format.
|
||||
@ -622,22 +707,25 @@ class _LiteFuncCall(object):
|
||||
inputs_str = "\tInputs\n" + format_args(self.inputs)
|
||||
outputs_str = "\tOutputs\n" + format_args(self.outputs)
|
||||
|
||||
return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s"
|
||||
% (self.function_name, self.uuid, inputs_str, outputs_str))
|
||||
return (
|
||||
"tflite function %s call %s level %d "
|
||||
"\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
|
||||
(self.function_name, self.uuid, self.level, inputs_str, outputs_str))
|
||||
|
||||
|
||||
def _find_all_hints_in_graph_def(graphdef):
|
||||
"""Look at the current default graph and return a list of LiteFuncCall objs.
|
||||
def _find_all_hints_in_nodes(nodes):
|
||||
"""Look at the all the input nodes and return a list of LiteFuncCall objs.
|
||||
|
||||
Args:
|
||||
graphdef: A TensorFlow graph_def to look for LiteFuncCalls.
|
||||
nodes: A TensorFlow graph_def to look for LiteFuncCalls.
|
||||
|
||||
Returns:
|
||||
a list of `LifeFuncCall` objects in the form
|
||||
|
||||
"""
|
||||
func_calls = _collections.defaultdict(_LiteFuncCall)
|
||||
|
||||
for node in graphdef.node:
|
||||
for node in nodes:
|
||||
attr = node.attr
|
||||
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
|
||||
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
|
||||
@ -649,6 +737,7 @@ def _find_all_hints_in_graph_def(graphdef):
|
||||
call_def = func_calls[uuid]
|
||||
call_def.uuid = uuid
|
||||
call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
|
||||
call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
|
||||
# Get sorting and aggregation information
|
||||
|
||||
sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
|
||||
@ -658,6 +747,10 @@ def _find_all_hints_in_graph_def(graphdef):
|
||||
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
|
||||
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
|
||||
|
||||
if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
|
||||
call_def.children_inputs_mappings = _json.loads(
|
||||
_compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))
|
||||
|
||||
# Add the input or output
|
||||
def put_operand(stuff, index, sort, operand, aggregation):
|
||||
"""Add a given index into the function structure."""
|
||||
@ -683,6 +776,98 @@ def _find_all_hints_in_graph_def(graphdef):
|
||||
return func_calls
|
||||
|
||||
|
||||
def _extract_topology_sequence_mapping(nodes):
|
||||
return dict(
|
||||
(_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))
|
||||
|
||||
|
||||
def _find_children_hints_in_while_loop(function_def, nodes_mapping):
|
||||
"""Find children hints and all nodes inside the while loop.
|
||||
|
||||
Args:
|
||||
function_def: Function def of the while loop.
|
||||
nodes_mapping: While loop input_arg : real node name.
|
||||
|
||||
Returns:
|
||||
Ordered children hints and all re-mapped nodes inside the while loop.
|
||||
"""
|
||||
new_nodes = []
|
||||
|
||||
# Make nodes inside function def inputs point to the real nodes.
|
||||
for node in function_def.node_def:
|
||||
for i in range(len(node.input)):
|
||||
if node.input[i] in nodes_mapping:
|
||||
node.input[i] = nodes_mapping[node.input[i]]
|
||||
new_nodes.append(_copy.deepcopy(node))
|
||||
name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
|
||||
children_hints = _find_all_hints_in_nodes(new_nodes)
|
||||
children_hints_q = []
|
||||
# Ordered by the outputs.
|
||||
for hint in _six.itervalues(children_hints):
|
||||
_, output_names = hint.flattened_inputs_and_outputs()
|
||||
seq = name_to_seq_num[output_names[0]]
|
||||
for output_name in output_names:
|
||||
seq = min(seq, name_to_seq_num[output_name])
|
||||
children_hints_q.append((seq, hint))
|
||||
children_hints_q.sort(key=lambda tup: tup[0])
|
||||
ordered_children_hints = [x[1] for x in children_hints_q]
|
||||
return ordered_children_hints, new_nodes
|
||||
|
||||
|
||||
def _find_children_hints(call, graph_def):
|
||||
"""Find all children hints.
|
||||
|
||||
For a given OpHint, we find all children hints inside it, we also copy all the
|
||||
nodes inside function defs (if applicable) to the original graph_def, they are
|
||||
returned in a list as well.
|
||||
|
||||
Args:
|
||||
call: Parent OpHint that contains children ophints.
|
||||
graph_def: Original graph def.
|
||||
|
||||
Returns:
|
||||
Ordered children hints inside the parent ophint; new graph def that contains
|
||||
nodes inside function defs (if applicable); nodes inside function defs.
|
||||
"""
|
||||
name_to_input_name, _, _ = _extract_graph_summary(graph_def)
|
||||
input_names, output_names = call.flattened_inputs_and_outputs()
|
||||
|
||||
reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
|
||||
reachable_by_output = _bfs_for_reachable_nodes(output_names,
|
||||
name_to_input_name)
|
||||
output_nodes_set = set(output_names)
|
||||
children_hints = []
|
||||
out = _graph_pb2.GraphDef()
|
||||
out.library.CopyFrom(graph_def.library)
|
||||
out.versions.CopyFrom(graph_def.versions)
|
||||
function_def_nodes = set()
|
||||
for node in graph_def.node:
|
||||
out.node.extend([_copy.deepcopy(node)])
|
||||
n = _tensor_name_base(node.name)
|
||||
if n in reachable_by_output:
|
||||
if n not in reachable_by_input and n not in output_nodes_set:
|
||||
# special handle for while loop function def.
|
||||
if node.op == "While":
|
||||
body_name = node.attr["body"].func.name
|
||||
inputs_outside_loop = node.input
|
||||
for function_def in graph_def.library.function:
|
||||
if function_def.signature.name == body_name:
|
||||
function_inputs = function_def.signature.input_arg
|
||||
assert len(inputs_outside_loop) == len(function_inputs)
|
||||
nodes_mapping = {}
|
||||
for i in range(len(function_inputs)):
|
||||
nodes_mapping[function_inputs[i].name] = inputs_outside_loop[i]
|
||||
# TODO(b/123050804): Consider use grappler.
|
||||
(children_hints_in_loop,
|
||||
new_nodes) = _find_children_hints_in_while_loop(
|
||||
function_def, nodes_mapping)
|
||||
function_def_nodes.update([x.name for x in new_nodes])
|
||||
children_hints.extend(children_hints_in_loop)
|
||||
out.node.extend(new_nodes)
|
||||
|
||||
return children_hints, out, function_def_nodes
|
||||
|
||||
|
||||
def _tensor_name_base(full_tensor_name):
|
||||
"""Removes the device assignment code from a tensor.
|
||||
|
||||
@ -735,12 +920,20 @@ def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
|
||||
|
||||
|
||||
# TODO(aselle): This should be converted to grappler in the future.
|
||||
def _convert_single_op_hint_to_stub(call, graph_def):
|
||||
def _convert_single_op_hint_to_stub(call,
|
||||
graph_def,
|
||||
function_def_nodes=None,
|
||||
is_last_run=True):
|
||||
"""Given a graph_def, converts `call` into a stub and returns a new graph_def.
|
||||
|
||||
Args:
|
||||
call: A single function call to be converted.
|
||||
graph_def: A graph_def to use as input (that hass call obviously).
|
||||
function_def_nodes: Nodes inside the function def those are not connected to
|
||||
the graph.
|
||||
is_last_run: Whether it is the last run for a given pass (for OpHint has
|
||||
children).
|
||||
|
||||
Returns:
|
||||
A new transformed graph-def that has call as a stub (single op).
|
||||
|
||||
@ -748,6 +941,8 @@ def _convert_single_op_hint_to_stub(call, graph_def):
|
||||
the tensorflow runtime, so all future manipulations are done in graph_def
|
||||
level.
|
||||
"""
|
||||
if function_def_nodes is None:
|
||||
function_def_nodes = set()
|
||||
name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
|
||||
graph_def)
|
||||
input_names, output_names = call.flattened_inputs_and_outputs()
|
||||
@ -755,7 +950,6 @@ def _convert_single_op_hint_to_stub(call, graph_def):
|
||||
reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
|
||||
reachable_by_output = _bfs_for_reachable_nodes(output_names,
|
||||
name_to_input_name)
|
||||
input_nodes_set = set(input_names)
|
||||
output_nodes_set = set(output_names)
|
||||
nodes_after_fuse = []
|
||||
nodes_deleted_by_fuse = set()
|
||||
@ -766,19 +960,16 @@ def _convert_single_op_hint_to_stub(call, graph_def):
|
||||
n = _tensor_name_base(node.name)
|
||||
if n in reachable_by_output:
|
||||
if n not in reachable_by_input and n not in output_nodes_set:
|
||||
# n is an internal node. Check to make sure it is really internal.
|
||||
# TODO(aselle): this could be done more efficiently by flooding
|
||||
# the graph first.
|
||||
_check_subgraph_closed(n, reachable_by_input, input_nodes_set,
|
||||
name_to_input_name)
|
||||
nodes_deleted_by_fuse.add(n)
|
||||
elif n not in reachable_by_input:
|
||||
elif n not in reachable_by_input and n not in function_def_nodes:
|
||||
# n is a node that after all the fusings, so keep it.
|
||||
nodes_after_fuse.append(n)
|
||||
else:
|
||||
# n is a node that is randomly in the graph but not connected to
|
||||
# the chain of dependencies.
|
||||
pass
|
||||
# In the last run, n is a node that is randomly in the graph but not
|
||||
# connected to the chain of dependencies, we will delete n, otherwise
|
||||
# we keep them.
|
||||
if not is_last_run:
|
||||
nodes_after_fuse.append(n)
|
||||
|
||||
# Make a new graphdef with all the pre-input and input nodes
|
||||
out = _graph_pb2.GraphDef()
|
||||
@ -800,7 +991,8 @@ def _convert_single_op_hint_to_stub(call, graph_def):
|
||||
# non-fused things.
|
||||
for input_index in sorted_input_indices:
|
||||
inputs = call.inputs[input_index]
|
||||
new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
|
||||
input_name = inputs.aggregate_and_return_name_for_input(out)
|
||||
new_node.input.append(input_name)
|
||||
new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
|
||||
|
||||
# Ceate the function
|
||||
@ -936,6 +1128,18 @@ def _remove_redundant_stack_unstack(graph_def):
|
||||
return curr
|
||||
|
||||
|
||||
def _get_correct_mapping(original_index, nodes):
|
||||
# Special handle for the index is -1 case.
|
||||
# If it is -1, return the last index.
|
||||
if original_index == -1:
|
||||
node_indices = nodes.keys()
|
||||
node_indices.sort()
|
||||
return node_indices[-1]
|
||||
else:
|
||||
return original_index
|
||||
return original_index
|
||||
|
||||
|
||||
@_tf_export("lite.convert_op_hints_to_stubs")
|
||||
def _convert_op_hints_to_stubs_helper(
|
||||
graph_def, write_callback=lambda sess, graph_def: None):
|
||||
@ -948,14 +1152,67 @@ def _convert_op_hints_to_stubs_helper(
|
||||
Returns:
|
||||
A new stubbed graph_def.
|
||||
"""
|
||||
hints = _find_all_hints_in_nodes(graph_def.node)
|
||||
|
||||
hints_q = []
|
||||
for hint in _six.itervalues(hints):
|
||||
hints_q.append((hint.level, hint.uuid))
|
||||
|
||||
hints_q.sort(key=lambda tup: tup[0])
|
||||
for i in range(len(hints_q) - 1, -1, -1):
|
||||
level, hint_uuid = hints_q[i]
|
||||
|
||||
hints = _find_all_hints_in_graph_def(graph_def)
|
||||
curr_graph_def = graph_def
|
||||
del graph_def # prevent using graph_def again (common source of error)
|
||||
for hint in _six.itervalues(hints):
|
||||
curr_graph_def = _convert_single_op_hint_to_stub(
|
||||
hint, curr_graph_def)
|
||||
write_callback(curr_graph_def, "initial")
|
||||
for i in range(len(hints_q) - 1, -1, -1):
|
||||
level, hint_uuid = hints_q[i]
|
||||
if level >= 2:
|
||||
children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
|
||||
hints[hint_uuid], curr_graph_def)
|
||||
# pylint: disable=superfluous-parens
|
||||
assert (len(children_hints) > 0) # pylint: disable=g-explicit-length-test
|
||||
# pylint: enable=superfluous-parens
|
||||
|
||||
# Re-wire the children hints inputs/outputs, so latter child's inputs
|
||||
# connect to previous child node's outputs.
|
||||
children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
|
||||
for j in range(len(children_hints)):
|
||||
child_hint = children_hints[j]
|
||||
if j == 0:
|
||||
for mapping in children_inputs_mappings["parent_first_child_input"]:
|
||||
parent_input_index = _get_correct_mapping(
|
||||
mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
|
||||
child_input_index = _get_correct_mapping(
|
||||
mapping["first_child_ophint_input_index"], child_hint.inputs)
|
||||
child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
|
||||
parent_input_index]
|
||||
else:
|
||||
for mapping in children_inputs_mappings[
|
||||
"internal_children_input_output"]:
|
||||
input_index = _get_correct_mapping(mapping["child_input_index"],
|
||||
child_hint.inputs)
|
||||
output_index = _get_correct_mapping(mapping["child_output_index"],
|
||||
children_hints[j - 1].outputs)
|
||||
child_hint.inputs[input_index] = children_hints[
|
||||
j - 1].outputs[output_index]
|
||||
if j == len(children_hints) - 1:
|
||||
for mapping in children_inputs_mappings["parent_last_child_output"]:
|
||||
parent_output_index = _get_correct_mapping(
|
||||
mapping["parent_output_index"], hints[hint_uuid].outputs)
|
||||
child_output_index = _get_correct_mapping(
|
||||
mapping["child_output_index"], child_hint.outputs)
|
||||
child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
|
||||
parent_output_index]
|
||||
|
||||
for j in range(len(children_hints)):
|
||||
child_hint = children_hints[j]
|
||||
curr_graph_def = _convert_single_op_hint_to_stub(
|
||||
child_hint, curr_graph_def, function_def_nodes,
|
||||
j == len(children_hints) - 1)
|
||||
else:
|
||||
curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
|
||||
curr_graph_def)
|
||||
write_callback(curr_graph_def, "initial")
|
||||
# The stubbing process can create stacks/unstacks in the case of LSTMs
|
||||
# remove them.
|
||||
curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
|
||||
@ -982,9 +1239,9 @@ def find_all_hinted_output_nodes(session=None, graph_def=None):
|
||||
raise ValueError("Provide only one of session and graph_def.")
|
||||
hinted_outputs_nodes = []
|
||||
if session is not None:
|
||||
hints = _find_all_hints_in_graph_def(session.graph_def)
|
||||
hints = _find_all_hints_in_nodes(session.graph_def.node)
|
||||
elif graph_def is not None:
|
||||
hints = _find_all_hints_in_graph_def(graph_def)
|
||||
hints = _find_all_hints_in_nodes(graph_def.node)
|
||||
for hint in _six.itervalues(hints):
|
||||
_, ouput_nodes = hint.flattened_inputs_and_outputs()
|
||||
hinted_outputs_nodes.extend(ouput_nodes)
|
||||
|
@ -143,13 +143,14 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
|
||||
# Breadth first search to find all the nodes that we should keep.
|
||||
next_to_visit = target_nodes[:]
|
||||
while next_to_visit:
|
||||
n = next_to_visit[0]
|
||||
node = next_to_visit[0]
|
||||
del next_to_visit[0]
|
||||
if n in nodes_to_keep:
|
||||
if node in nodes_to_keep:
|
||||
# Already visited this node.
|
||||
continue
|
||||
nodes_to_keep.add(n)
|
||||
next_to_visit += name_to_input_name[n]
|
||||
nodes_to_keep.add(node)
|
||||
if node in name_to_input_name:
|
||||
next_to_visit += name_to_input_name[node]
|
||||
return nodes_to_keep
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add"
|
||||
|
@ -14,6 +14,10 @@ tf_class {
|
||||
name: "AGGREGATE_STACK"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "CHILDREN_INPUTS_MAPPINGS"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "FUNCTION_AGGREGATE_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
@ -22,6 +26,10 @@ tf_class {
|
||||
name: "FUNCTION_INPUT_INDEX_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "FUNCTION_LEVEL_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "FUNCTION_NAME_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
@ -48,7 +56,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
argspec: "args=[\'self\', \'function_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_input"
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add"
|
||||
|
@ -14,6 +14,10 @@ tf_class {
|
||||
name: "AGGREGATE_STACK"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "CHILDREN_INPUTS_MAPPINGS"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "FUNCTION_AGGREGATE_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
@ -22,6 +26,10 @@ tf_class {
|
||||
name: "FUNCTION_INPUT_INDEX_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "FUNCTION_LEVEL_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "FUNCTION_NAME_ATTR"
|
||||
mtype: "<type \'str\'>"
|
||||
@ -48,7 +56,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
argspec: "args=[\'self\', \'function_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_input"
|
||||
|
Loading…
x
Reference in New Issue
Block a user