diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 8e8dd3afce1..a4826e5b607 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -225,7 +225,8 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): # configured improperly. constants[i] = op_input else: - constants[i] = backend.function([], op_input)([]) + with ops.init_scope(): + constants[i] = backend.function([], op_input)([]) processed_ops, created_layers = _create_keras_history_helper( layer_inputs, processed_ops, created_layers) name = op.name @@ -239,7 +240,7 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): return processed_ops, created_layers -def needs_keras_history(tensors): +def needs_keras_history(tensors, ignore_call_context=False): """Check if any Tensors need to be wrapped in TensorFlowOpLayers. This will never return True inside a sublayer, because sublayers @@ -249,12 +250,18 @@ def needs_keras_history(tensors): Arguments: tensors: An arbitrary nested structure of Tensors. + ignore_call_context: Whether to ignore the check of if currently + outside of a `call` context. This is `True` when creating + KerasHistory inside `Node`, where we always know that Tensors + are being used with the Functional API. Returns: Bool, whether at least one Tensor needs to be wrapped. """ input_tensors = nest.flatten(tensors) - if call_context().in_call or all( + if call_context().in_call and not ignore_call_context: + return False + if all( getattr(tensor, '_keras_history', None) is not None for tensor in input_tensors): # KerasHistory already set. diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 534d73a137e..ff5a479a01a 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1547,7 +1547,7 @@ class Network(base_layer.Layer): def _get_min_depth(node): """Gets the minimum depth at which node can be computed.""" min_depth = 0 - for layer, node_id, _, _ in node.iterate_inbound(): + for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True): inbound_node = layer._inbound_nodes[node_id] if inbound_node in node_to_depth: min_depth = min(min_depth, node_to_depth[inbound_node]) @@ -1720,7 +1720,8 @@ def _map_graph_network(inputs, outputs): nodes_in_progress.add(node) # Propagate to all previous tensors connected to this node. - for layer, node_index, tensor_index, tensor in node.iterate_inbound(): + for layer, node_index, tensor_index, tensor in node.iterate_inbound( + include_arguments=True): build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index) diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py index 9a7ecb79c47..4e005071c6e 100644 --- a/tensorflow/python/keras/engine/node.py +++ b/tensorflow/python/keras/engine/node.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.keras import backend +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.util import nest @@ -111,6 +112,15 @@ class Node(object): # Optional keyword arguments to layer's `call`. self.arguments = arguments + # Create Keras History for any Keras Tensors in `arguments`. + tensor_arguments = [ + t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor) + ] + for tensor_argument in tensor_arguments: + if base_layer_utils.needs_keras_history( + tensor_argument, ignore_call_context=True): + base_layer_utils.create_keras_history(tensor_argument) + # Add nodes to all layers involved. for layer in nest.flatten(inbound_layers): if layer is not None: @@ -121,15 +131,39 @@ class Node(object): # accessor here. outbound_layer.inbound_nodes.append(self) - def iterate_inbound(self): + def iterate_inbound(self, include_arguments=False): """Returns a list of tuples representing the inbound data. + Arguments: + include_arguments: Whether to also iterate over any Keras Tensors + passed as args, kwargs. + Returns: List of tuples like: (inbound_layer, node_index, tensor_index, tensor). """ - return zip( - nest.flatten(self.inbound_layers), nest.flatten(self.node_indices), - nest.flatten(self.tensor_indices), nest.flatten(self.input_tensors)) + inputs_inbound = list( + zip( + nest.flatten(self.inbound_layers), + nest.flatten(self.node_indices), + nest.flatten(self.tensor_indices), + nest.flatten(self.input_tensors))) + + if include_arguments: + keras_tensor_arguments = [ + kt for kt in nest.flatten(self.arguments) + if hasattr(kt, '_keras_history') + ] + + def _get_inbound(keras_tensor): + kh = keras_tensor._keras_history + return kh.layer, kh.node_index, kh.tensor_index, keras_tensor + + arguments_inbound = nest.map_structure(_get_inbound, + keras_tensor_arguments) + + return inputs_inbound + arguments_inbound + else: + return inputs_inbound def _get_all_node_dependencies(self): """Returns all of the nodes this node immediately depends on.""" diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index a853ce5eed0..a43e983bdfd 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -135,6 +135,19 @@ def _float64_op(): return keras.Model(inputs, outputs) +class MyAdd(keras.layers.Layer): + + def call(self, x, y): + return x + y + + +def _layer_with_tensor_arg(): + inputs = keras.Input(shape=(10,)) + x = inputs * 2 + outputs = MyAdd()(inputs, x) + return keras.Model(inputs, outputs) + + class LayerWithLayer(keras.layers.Layer): def build(self, input_shape): @@ -191,6 +204,7 @@ class AutoLambdaTest(keras_parameterized.TestCase): ('_float64_op', _float64_op), ('_inner_layer', _inner_layer), ('_reuse_ancillary_layer', _reuse_ancillary_layer), + ('_layer_with_tensor_arg', _layer_with_tensor_arg), ) def test_autolambda(self, model_fn): model = model_fn() @@ -208,7 +222,11 @@ class AutoLambdaTest(keras_parameterized.TestCase): model(np_inputs) # Test calling the model directly on inputs. new_model = keras.Model.from_config( - model.get_config(), custom_objects={'LayerWithLayer': LayerWithLayer}) + model.get_config(), + custom_objects={ + 'LayerWithLayer': LayerWithLayer, + 'MyAdd': MyAdd + }) new_model.compile( adam.Adam(0.001), 'mse',