Add support for TF op whitelisting in Keras Layers with Keras Tensors as positional args, kwargs.

PiperOrigin-RevId: 262417758
Thomas O'Malley authored on 2019-08-08 13:13:48 -07:00; committed by Goldie Gadde
parent ff98617eb0
commit 8b21ae572e
4 changed files with 70 additions and 10 deletions

tensorflow/python/keras/engine/base_layer_utils.py

@@ -225,7 +225,8 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
             # configured improperly.
             constants[i] = op_input
           else:
-            constants[i] = backend.function([], op_input)([])
+            with ops.init_scope():
+              constants[i] = backend.function([], op_input)([])
       processed_ops, created_layers = _create_keras_history_helper(
           layer_inputs, processed_ops, created_layers)
       name = op.name
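For context, a minimal sketch (not part of the commit) of what the new `ops.init_scope()` guard buys: if Keras-history creation is triggered while a FuncGraph is being traced, `init_scope` lifts execution back to the outermost context so the constant tensor can actually be evaluated. The helper name below is hypothetical.

    from tensorflow.python.framework import ops
    from tensorflow.python.keras import backend

    def read_constant_value(op_input):  # hypothetical helper, illustration only
      # `init_scope` exits any `tf.function` trace in progress, so the
      # `backend.function` call runs in the outermost eager/graph context
      # rather than inside the FuncGraph being built.
      with ops.init_scope():
        return backend.function([], op_input)([])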
@@ -239,7 +240,7 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
   return processed_ops, created_layers


-def needs_keras_history(tensors):
+def needs_keras_history(tensors, ignore_call_context=False):
   """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

   This will never return True inside a sublayer, because sublayers
@@ -249,12 +250,18 @@ def needs_keras_history(tensors):
   Arguments:
     tensors: An arbitrary nested structure of Tensors.
+    ignore_call_context: Whether to ignore the check of if currently
+      outside of a `call` context. This is `True` when creating
+      KerasHistory inside `Node`, where we always know that Tensors
+      are being used with the Functional API.

   Returns:
     Bool, whether at least one Tensor needs to be wrapped.
   """
   input_tensors = nest.flatten(tensors)
-  if call_context().in_call or all(
+  if call_context().in_call and not ignore_call_context:
+    return False
+  if all(
       getattr(tensor, '_keras_history', None) is not None
       for tensor in input_tensors):
     # KerasHistory already set.
     return False
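What this change enables, as an illustrative sketch (the `Scale` layer below is a hypothetical example, not from the commit): a tensor produced by a raw TF op can now be passed to a layer's `call` as a positional or keyword argument and still be traced into the functional graph.

    import tensorflow as tf
    from tensorflow import keras

    class Scale(keras.layers.Layer):  # hypothetical example layer

      def call(self, x, scale):
        return x * scale

    inputs = keras.Input(shape=(4,))
    # Output of a raw TF op, not of a Keras layer:
    scale = tf.reduce_mean(inputs, axis=-1, keepdims=True)
    # The Node created for this call now builds KerasHistory for `scale`,
    # wrapping `reduce_mean` in a TensorFlowOpLayer behind the scenes.
    outputs = Scale()(inputs, scale=scale)
    model = keras.Model(inputs, outputs)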

tensorflow/python/keras/engine/network.py

@@ -1547,7 +1547,7 @@ class Network(base_layer.Layer):
     def _get_min_depth(node):
       """Gets the minimum depth at which node can be computed."""
       min_depth = 0
-      for layer, node_id, _, _ in node.iterate_inbound():
+      for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True):
         inbound_node = layer._inbound_nodes[node_id]
         if inbound_node in node_to_depth:
           min_depth = min(min_depth, node_to_depth[inbound_node])
@@ -1720,7 +1720,8 @@ def _map_graph_network(inputs, outputs):
     nodes_in_progress.add(node)

     # Propagate to all previous tensors connected to this node.
-    for layer, node_index, tensor_index, tensor in node.iterate_inbound():
+    for layer, node_index, tensor_index, tensor in node.iterate_inbound(
+        include_arguments=True):
       build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
                 tensor_index)
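Continuing the `Scale` sketch above (still illustrative): because graph traversal now follows argument tensors, the auto-created op layer feeding `scale` is reached by `_map_graph_network` and included in the model's depth ordering.

    # Expected to list an InputLayer, a TensorFlowOpLayer for `reduce_mean`,
    # and Scale; exact class names may vary by TF version.
    print([layer.__class__.__name__ for layer in model.layers])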

tensorflow/python/keras/engine/node.py

@@ -20,6 +20,7 @@ from __future__ import print_function

 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
+from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.util import nest
@@ -111,6 +112,15 @@ class Node(object):
     # Optional keyword arguments to layer's `call`.
     self.arguments = arguments

+    # Create Keras History for any Keras Tensors in `arguments`.
+    tensor_arguments = [
+        t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
+    ]
+    for tensor_argument in tensor_arguments:
+      if base_layer_utils.needs_keras_history(
+          tensor_argument, ignore_call_context=True):
+        base_layer_utils.create_keras_history(tensor_argument)
+
     # Add nodes to all layers involved.
     for layer in nest.flatten(inbound_layers):
       if layer is not None:
@@ -121,15 +131,39 @@ class Node(object):
       # accessor here.
       outbound_layer.inbound_nodes.append(self)

-  def iterate_inbound(self):
+  def iterate_inbound(self, include_arguments=False):
     """Returns a list of tuples representing the inbound data.

+    Arguments:
+      include_arguments: Whether to also iterate over any Keras Tensors
+        passed as args, kwargs.
+
     Returns:
       List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
     """
-    return zip(
-        nest.flatten(self.inbound_layers), nest.flatten(self.node_indices),
-        nest.flatten(self.tensor_indices), nest.flatten(self.input_tensors))
+    inputs_inbound = list(
+        zip(
+            nest.flatten(self.inbound_layers),
+            nest.flatten(self.node_indices),
+            nest.flatten(self.tensor_indices),
+            nest.flatten(self.input_tensors)))
+
+    if include_arguments:
+      keras_tensor_arguments = [
+          kt for kt in nest.flatten(self.arguments)
+          if hasattr(kt, '_keras_history')
+      ]
+
+      def _get_inbound(keras_tensor):
+        kh = keras_tensor._keras_history
+        return kh.layer, kh.node_index, kh.tensor_index, keras_tensor
+
+      arguments_inbound = nest.map_structure(_get_inbound,
+                                             keras_tensor_arguments)
+
+      return inputs_inbound + arguments_inbound
+    else:
+      return inputs_inbound

   def _get_all_node_dependencies(self):
     """Returns all of the nodes this node immediately depends on."""

tensorflow/python/keras/engine/tensorflow_op_layer_test.py

@@ -135,6 +135,19 @@ def _float64_op():
   return keras.Model(inputs, outputs)


+class MyAdd(keras.layers.Layer):
+
+  def call(self, x, y):
+    return x + y
+
+
+def _layer_with_tensor_arg():
+  inputs = keras.Input(shape=(10,))
+  x = inputs * 2
+  outputs = MyAdd()(inputs, x)
+  return keras.Model(inputs, outputs)
+
+
 class LayerWithLayer(keras.layers.Layer):

   def build(self, input_shape):
@@ -191,6 +204,7 @@ class AutoLambdaTest(keras_parameterized.TestCase):
       ('_float64_op', _float64_op),
       ('_inner_layer', _inner_layer),
       ('_reuse_ancillary_layer', _reuse_ancillary_layer),
+      ('_layer_with_tensor_arg', _layer_with_tensor_arg),
   )
   def test_autolambda(self, model_fn):
     model = model_fn()
@@ -208,7 +222,11 @@ class AutoLambdaTest(keras_parameterized.TestCase):
     model(np_inputs)  # Test calling the model directly on inputs.

     new_model = keras.Model.from_config(
-        model.get_config(), custom_objects={'LayerWithLayer': LayerWithLayer})
+        model.get_config(),
+        custom_objects={
+            'LayerWithLayer': LayerWithLayer,
+            'MyAdd': MyAdd
+        })
     new_model.compile(
         adam.Adam(0.001),
         'mse',
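A quick end-to-end check mirroring what `test_autolambda` exercises with the new fixture (illustrative, not part of the commit):

    import numpy as np

    model = _layer_with_tensor_arg()
    model.compile('adam', 'mse')
    model.fit(np.ones((2, 10)), np.ones((2, 10)), verbose=0)
    # Round-trips through config as in the test above:
    model = keras.Model.from_config(
        model.get_config(), custom_objects={'MyAdd': MyAdd})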