diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index f76854706e8..c79237bb727 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -133,6 +133,7 @@ class _DummyEagerGraph(threading.local):
     # get a different key.
     super(_DummyEagerGraph, self).__init__()
     self.key = _DummyEagerGraph._WeakReferencableClass()
+    self.learning_phase_is_set = False


 _DUMMY_EAGER_GRAPH = _DummyEagerGraph()
@@ -295,6 +296,7 @@ def clear_session():
     _SESSION.session = None
     graph = get_graph()
     with graph.as_default():
+      _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
       _GRAPH_LEARNING_PHASES.clear()
       # Create the learning phase placeholder in graph using the default factory.
       _GRAPH_LEARNING_PHASES.setdefault(graph)
@@ -351,7 +353,7 @@ def learning_phase():


 def global_learning_phase_is_set():
-  return _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES
+  return _DUMMY_EAGER_GRAPH.learning_phase_is_set


 def _mark_func_graph_as_unsaveable(graph, learning_phase):
@@ -420,6 +422,7 @@ def set_learning_phase(value):
   if context.executing_eagerly():
     # In an eager context, the learning phase value applies to both the eager
    # context and the internal Keras graph.
+    _DUMMY_EAGER_GRAPH.learning_phase_is_set = True
    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
  _GRAPH_LEARNING_PHASES[get_graph()] = value
@@ -451,11 +454,14 @@ def learning_phase_scope(value):
         _DUMMY_EAGER_GRAPH.key, None)
     previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)

+  learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
   try:
     set_learning_phase(value)
     yield
   finally:
     # Restore learning phase to initial value.
+    if not learning_phase_previously_set:
+      _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
     with ops.init_scope():
       if context.executing_eagerly():
         if previous_eager_value is not None:
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index 5a24be11f67..b010b1f2b18 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -495,6 +495,20 @@ tf_py_test(
     ],
 )

+tf_py_test(
+    name = "node_test",
+    size = "medium",
+    srcs = ["node_test.py"],
+    python_version = "PY3",
+    shard_count = 3,
+    deps = [
+        ":base_layer",
+        ":engine",
+        "//tensorflow/python/keras:testing_utils",
+        "//tensorflow/python/keras/utils:layer_utils",
+    ],
+)
+
 tf_py_test(
     name = "base_layer_test",
     size = "medium",
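[Editor's note] The `learning_phase_is_set` flag introduced in backend.py above follows a save-and-restore pattern: `learning_phase_scope` records whether the phase had already been set globally, and only clears the flag on exit when this scope was the one that set it, so an enclosing `set_learning_phase` call is not clobbered. A minimal standalone sketch of that pattern (all names here are illustrative stand-ins, not part of this patch):

import contextlib

_state = {'is_set': False, 'value': None}

def set_phase(value):
  _state['is_set'] = True
  _state['value'] = value

@contextlib.contextmanager
def phase_scope(value):
  previously_set = _state['is_set']
  previous_value = _state['value']
  try:
    set_phase(value)
    yield
  finally:
    # Only clear the "globally set" flag if an enclosing caller had not
    # already set it; always restore the previous value.
    if not previously_set:
      _state['is_set'] = False
    _state['value'] = previous_value

# Usage: inside phase_scope(1) the phase reads as set; on exit both the flag
# and the value are restored to what the surrounding code established.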
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 7a6ba186bc3..7d12d1635ac 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import collections
+import copy
 import functools
 import itertools
 import threading
@@ -797,15 +797,22 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       raise RuntimeError(
           'You must call `super().__init__()` in the layer constructor.')

-    # Grab the first positional or keyword argument.
-    if args:
-      inputs = args[0]
-      args = args[1:]
-    elif self._call_fn_args[0] in kwargs:
-      inputs = kwargs.pop(self._call_fn_args[0])
-    else:
-      raise ValueError(
-          'The first argument to `Layer.call` must always be passed.')
+    # `inputs` (the first arg in the method spec) is special-cased in
+    # layer call due to historical reasons.
+    # This special casing currently takes the form of:
+    # - `inputs` must be explicitly passed. A layer cannot have zero arguments,
+    #   and `inputs` cannot have been provided via the default value of a kwarg.
+    # - numpy/scalar values in `inputs` get converted to tensors
+    # - implicit masks / mask metadata are only collected from `inputs`
+    # - Layers are built using shape info from `inputs` only
+    # - input_spec compatibility is only checked against `inputs`
+    # - checking if a layer has ragged tensor support is only done against
+    #   `inputs`
+    # - mixed precision casting (autocast) is only applied to `inputs`,
+    #   not to any other argument.
+    # - the Functional API SavedModel saving spec (which decides what should
+    #   be serialized during SavedModel saving) is configured from `inputs` only
+    inputs, args, kwargs = self._split_out_first_arg(args, kwargs)

     call_context = base_layer_utils.call_context()
     input_list = nest.flatten(inputs)
@@ -814,6 +821,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # This is always the case in graph mode. It can also be the case in eager
     # mode when all inputs can be traced back to `keras.Input()` (when building
     # models using the functional API).
+    # TODO(kaftan): make this not special case inputs. Instead
+    # build a functional api model if *any* *arg or **kwarg is symbolic,
+    # even if part of the data structure in that arg is not symbolic.
     build_graph = tf_utils.are_all_symbolic_tensors(input_list)

     # Accept NumPy and scalar inputs by converting to Tensors.
@@ -838,16 +848,17 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         mask_arg_passed_by_framework = True
         kwargs['mask'] = input_masks

-    # If `training` argument was not explicitly passed, propagate `training`
-    # value from this layer's calling layer.
+    # If the `training` argument is None or was not explicitly passed,
+    # propagate the `training` value from this layer's calling layer.
+    training_value = None
     training_arg_passed_by_framework = False
     # Priority 1: `training` was explicitly passed.
     if self._call_arg_was_passed('training', args, kwargs):
       training_value = self._get_call_arg_value('training', args, kwargs)
       if not self._expects_training_arg:
         kwargs.pop('training')
-    else:
-      training_value = None
+
+    if training_value is None:
       # Priority 2: `training` was passed to a parent layer.
       if call_context.training is not None:
         training_value = call_context.training
@@ -867,12 +878,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
             training_value = math_ops.cast(training_value, dtypes.bool)
           else:
             training_value = bool(training_value)
-        kwargs['training'] = training_value
+        args, kwargs = self._set_call_arg_value(
+            'training', training_value, args, kwargs)
         training_arg_passed_by_framework = True

     # Only create Keras history if at least one tensor originates from a
     # `keras.Input`. Otherwise this Layer may be being used outside the Keras
     # framework.
+ # TODO(kaftan): make this not special case inputs if build_graph and base_layer_utils.needs_keras_history(inputs): base_layer_utils.create_keras_history(inputs) @@ -953,13 +966,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector): raise ValueError('A layer\'s `call` method should return a ' 'Tensor or a list of Tensors, not None ' '(layer: ' + self.name + ').') + # TODO(kaftan): This should be 'any' and check all args if base_layer_utils.have_all_keras_metadata(inputs): if training_arg_passed_by_framework: - kwargs.pop('training') + args, kwargs = self._set_call_arg_value( + 'training', None, args, kwargs, pop_kwarg_if_none=True) if mask_arg_passed_by_framework: kwargs.pop('mask') - inputs, outputs = self._set_connectivity_metadata_( - inputs, outputs, args, kwargs) + # Node connectivity does not special-case the first argument. + outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, + outputs) self._handle_activity_regularization(inputs, outputs) self._set_mask_metadata(inputs, outputs, input_masks) if hasattr(self, '_set_inputs') and not self.inputs: @@ -2299,70 +2315,45 @@ class Layer(module.Module, version_utils.LayerVersionSelector): args_dict = dict(zip(call_fn_args, args)) return args_dict[arg_name] - def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): + def _set_call_arg_value( + self, arg_name, new_value, args, + kwargs, inputs_in_args=False, pop_kwarg_if_none=False): + arg_pos = self._call_fn_arg_positions.get(arg_name, None) + if arg_pos is not None: + if not inputs_in_args: + # Ignore `inputs` arg. + arg_pos = arg_pos - 1 + if len(args) > arg_pos: + args = list(args) + args[arg_pos] = new_value + return args, kwargs + if new_value is None and pop_kwarg_if_none: + kwargs.pop(arg_name, None) + else: + kwargs[arg_name] = new_value + return args, kwargs - # If the layer returns tensors from its inputs, unmodified, - # we copy them to avoid loss of tensor metadata. - output_ls = nest.flatten(outputs) - inputs_ls = object_identity.ObjectIdentitySet(nest.flatten(inputs)) - output_ls_copy = [] - for x in output_ls: - if x in inputs_ls: + def _set_connectivity_metadata(self, args, kwargs, outputs): + # If the layer returns tensors from its inputs unmodified, + # we copy them to avoid loss of KerasHistory metadata. + flat_outputs = nest.flatten(outputs) + flat_inputs = nest.flatten((args, kwargs)) + inputs_set = object_identity.ObjectIdentitySet(flat_inputs) + outputs_copy = [] + for x in flat_outputs: + if x in inputs_set: with backend.name_scope(self.name): x = array_ops.identity(x) - output_ls_copy.append(x) - outputs = nest.pack_sequence_as(outputs, output_ls_copy) + outputs_copy.append(x) + outputs = nest.pack_sequence_as(outputs, outputs_copy) - # Ignore `inputs` arg. - arguments = dict(zip(self._call_fn_args[1:], args)) - arguments.update(kwargs) - - # Add an inbound node to the layer, so it can keep track of this call. - # This updates the layer history of the output tensor(s). - self._add_inbound_node( - input_tensors=inputs, output_tensors=outputs, arguments=arguments) - return inputs, outputs - - def _add_inbound_node(self, - input_tensors, - output_tensors, - arguments=None): - """Internal method to create an inbound node for the layer. - - Arguments: - input_tensors: list of input tensors. - output_tensors: list of output tensors. - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. 
- """ - inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, - input_tensors) - node_indices = nest.map_structure(lambda t: t._keras_history.node_index, - input_tensors) - tensor_indices = nest.map_structure(lambda t: t._keras_history.tensor_index, - input_tensors) - - # Create node, add it to inbound nodes. - node_module.Node( - self, - inbound_layers=inbound_layers, - node_indices=node_indices, - tensor_indices=tensor_indices, - input_tensors=input_tensors, - output_tensors=output_tensors, - arguments=arguments) - - # Update tensor history metadata. - # The metadata attribute consists of - # 1) a layer instance - # 2) a node index for the layer - # 3) a tensor index for the node. - # The allows layer reuse (multiple nodes per layer) and multi-output - # or multi-input layers (e.g. a layer can return multiple tensors, - # and each can be sent to a different layer). - for i, tensor in enumerate(nest.flatten(output_tensors)): - tensor._keras_history = KerasHistory(self, - len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access + # Create node, Node wires itself to inbound and outbound layers. + # The Node constructor actually updates this layer's self._inbound_nodes, + # sets _keras_history on the outputs, and adds itself to the + # `_outbound_nodes` of the layers that produced the inputs to this + # layer call. + node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs) + return outputs def _get_node_attribute_at_index(self, node_index, attr, attr_name): """Private utility to retrieves an attribute (e.g. inputs) from a node. @@ -2706,6 +2697,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector): return all_args[1:] return all_args + @property + @tracking.cached_per_instance + def _call_fn_arg_positions(self): + call_fn_arg_positions = dict() + for pos, arg in enumerate(self._call_fn_args): + call_fn_arg_positions[arg] = pos + return call_fn_arg_positions + @property @tracking.cached_per_instance def _call_accepts_kwargs(self): @@ -2743,6 +2742,21 @@ class Layer(module.Module, version_utils.LayerVersionSelector): seen_weights.add(w) return output + def _split_out_first_arg(self, args, kwargs): + # Grab the argument corresponding to the first argument in the + # layer's `call` method spec. This will either be the first positional + # argument, or it will be provided as a keyword argument. + if args: + inputs = args[0] + args = args[1:] + elif self._call_fn_args[0] in kwargs: + kwargs = copy.copy(kwargs) + inputs = kwargs.pop(self._call_fn_args[0]) + else: + raise ValueError( + 'The first argument to `Layer.call` must always be passed.') + return inputs, args, kwargs + # SavedModel properties. Please see keras/saving/saved_model for details. @property @@ -2952,31 +2966,6 @@ class AddMetric(Layer): return config -class KerasHistory( - collections.namedtuple('KerasHistory', - ['layer', 'node_index', 'tensor_index'])): - """Tracks the Layer call that created a Tensor, for Keras Graph Networks. - - During construction of Keras Graph Networks, this metadata is added to - each Tensor produced as the output of a Layer, starting with an - `InputLayer`. This allows Keras to track how each Tensor was produced, and - this information is later retraced by the `keras.engine.Network` class to - reconstruct the Keras Graph Network. - - Attributes: - layer: The Layer that produced the Tensor. - node_index: The specific call to the Layer that produced this Tensor. Layers - can be called multiple times in order to share weights. 
A new node is - created every time a Layer is called. - tensor_index: The output index for this Tensor. Always zero if the Layer - that produced this Tensor only has one output. Nested structures of - Tensors are deterministically assigned an index via `nest.flatten`. - """ - # Added to maintain memory and performance characteristics of `namedtuple` - # while subclassing. - __slots__ = () - - # Avoid breaking users who directly import this symbol from this file. # TODO(fchollet): remove this. InputSpec = input_spec.InputSpec # pylint:disable=invalid-name diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index a0e1e9edc2f..30ac17d8270 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -254,8 +254,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): op_layer = base_layer.TensorFlowOpLayer( node_def, constants=constants, name=name) created_layers.append(op_layer) - op_layer._add_inbound_node( # pylint: disable=protected-access - layer_inputs, op.outputs) + op_layer._set_connectivity_metadata( # pylint: disable=protected-access + args=(layer_inputs,), + kwargs={}, + outputs=op.outputs) processed_ops.update([op]) return processed_ops, created_layers diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 7b4ce8ad54c..9d6f5afd240 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -45,7 +45,6 @@ from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec -from tensorflow.python.keras.engine import node as node_module from tensorflow.python.keras.mixed_precision.experimental import autocast_variable from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer from tensorflow.python.keras.mixed_precision.experimental import policy @@ -698,16 +697,17 @@ class Layer(base_layer.Layer): mask_arg_passed_by_framework = True kwargs['mask'] = input_masks - # If `training` argument was not explicitly passed, propagate `training` - # value from this layer's calling layer. + # If `training` argument is None or not explicitly passed, + # propagate `training` value from this layer's calling layer. + training_value = None training_arg_passed_by_framework = False # Priority 1: `training` was explicitly passed. if self._call_arg_was_passed('training', args, kwargs): training_value = self._get_call_arg_value('training', args, kwargs) if not self._expects_training_arg: kwargs.pop('training') - else: - training_value = None + + if training_value is None: # Priority 2: `training` was passed to a parent layer. 
if call_context.training is not None: training_value = call_context.training @@ -727,7 +727,8 @@ class Layer(base_layer.Layer): training_value = math_ops.cast(training_value, dtypes.bool) else: training_value = bool(training_value) - kwargs['training'] = training_value + args, kwargs = self._set_call_arg_value( + 'training', training_value, args, kwargs) training_arg_passed_by_framework = True # Only create Keras history if at least one tensor originates from a @@ -798,11 +799,12 @@ class Layer(base_layer.Layer): '(layer: ' + self.name + ').') if base_layer_utils.have_all_keras_metadata(inputs): if training_arg_passed_by_framework: - kwargs.pop('training') + args, kwargs = self._set_call_arg_value( + 'training', None, args, kwargs, pop_kwarg_if_none=True) if mask_arg_passed_by_framework: kwargs.pop('mask') - inputs, outputs = self._set_connectivity_metadata_( - inputs, outputs, args, kwargs) + outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, + outputs) self._handle_activity_regularization(inputs, outputs) self._set_mask_metadata(inputs, outputs, input_masks) if hasattr(self, '_set_inputs') and not self.inputs: @@ -2005,70 +2007,23 @@ class Layer(base_layer.Layer): args_dict = dict(zip(call_fn_args, args)) return args_dict[arg_name] - def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): - - # If the layer returns tensors from its inputs, unmodified, - # we copy them to avoid loss of tensor metadata. - output_ls = nest.flatten(outputs) - inputs_ls = object_identity.ObjectIdentitySet(nest.flatten(inputs)) - output_ls_copy = [] - for x in output_ls: - if x in inputs_ls: - with backend.name_scope(self.name): - x = array_ops.identity(x) - output_ls_copy.append(x) - outputs = nest.pack_sequence_as(outputs, output_ls_copy) - - # Ignore `inputs` arg. - arguments = dict(zip(self._call_fn_args[1:], args)) - arguments.update(kwargs) - - # Add an inbound node to the layer, so it can keep track of this call. - # This updates the layer history of the output tensor(s). - self._add_inbound_node( - input_tensors=inputs, output_tensors=outputs, arguments=arguments) - return inputs, outputs - - def _add_inbound_node(self, - input_tensors, - output_tensors, - arguments=None): - """Internal method to create an inbound node for the layer. - - Arguments: - input_tensors: list of input tensors. - output_tensors: list of output tensors. - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. - """ - inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, - input_tensors) - node_indices = nest.map_structure(lambda t: t._keras_history.node_index, - input_tensors) - tensor_indices = nest.map_structure(lambda t: t._keras_history.tensor_index, - input_tensors) - - # Create node, add it to inbound nodes. - node_module.Node( - self, - inbound_layers=inbound_layers, - node_indices=node_indices, - tensor_indices=tensor_indices, - input_tensors=input_tensors, - output_tensors=output_tensors, - arguments=arguments) - - # Update tensor history metadata. - # The metadata attribute consists of - # 1) a layer instance - # 2) a node index for the layer - # 3) a tensor index for the node. - # The allows layer reuse (multiple nodes per layer) and multi-output - # or multi-input layers (e.g. a layer can return multiple tensors, - # and each can be sent to a different layer). 
- for i, tensor in enumerate(nest.flatten(output_tensors)): - tensor._keras_history = KerasHistory(self, - len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access + def _set_call_arg_value( + self, arg_name, new_value, args, + kwargs, inputs_in_args=False, pop_kwarg_if_none=False): + arg_pos = self._call_fn_arg_positions.get(arg_name, None) + if arg_pos is not None: + if not inputs_in_args: + # Ignore `inputs` arg. + arg_pos = arg_pos - 1 + if len(args) > arg_pos: + args = list(args) + args[arg_pos] = new_value + return args, kwargs + if new_value is None and pop_kwarg_if_none: + kwargs.pop(arg_name, None) + else: + kwargs[arg_name] = new_value + return args, kwargs def _get_node_attribute_at_index(self, node_index, attr, attr_name): """Private utility to retrieves an attribute (e.g. inputs) from a node. @@ -2395,6 +2350,14 @@ class Layer(base_layer.Layer): return all_args[1:] return all_args + @property + @tracking.cached_per_instance + def _call_fn_arg_positions(self): + call_fn_arg_positions = dict() + for pos, arg in enumerate(self._call_fn_args): + call_fn_arg_positions[arg] = pos + return call_fn_arg_positions + @property @tracking.cached_per_instance def _call_accepts_kwargs(self): diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index efa3de959f1..1664e927da7 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -164,17 +164,9 @@ class InputLayer(base_layer.Layer): self.is_placeholder = False self._batch_input_shape = tuple(input_tensor.shape.as_list()) - # Create an input node to add to self.outbound_node - # and set output_tensors' _keras_history. - input_tensor._keras_history = base_layer.KerasHistory(self, 0, 0) + # Create an input node. input_tensor._keras_mask = None - node_module.Node( - self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=[input_tensor], - output_tensors=[input_tensor]) + node_module.Node(layer=self, outputs=input_tensor) def get_config(self): config = { @@ -294,8 +286,8 @@ def Input( # pylint: disable=invalid-name # Return tensor including `_keras_history`. # Note that in this case train_output and test_output are the same pointer. 
-  outputs = input_layer._inbound_nodes[0].output_tensors
-  if len(outputs) == 1:
+  outputs = input_layer._inbound_nodes[0].outputs
+  if isinstance(outputs, list) and len(outputs) == 1:
     return outputs[0]
   else:
     return outputs
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 770f0468881..807576cb45b 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -25,13 +25,11 @@ import itertools
 import json
 import os

-import numpy as np
 import six
 from six.moves import zip  # pylint: disable=redefined-builtin

 from tensorflow.python.eager import context
 from tensorflow.python.framework import composite_tensor
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import func_graph
@@ -42,7 +40,6 @@ from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import compile_utils
 from tensorflow.python.keras.engine import input_layer as input_layer_module
-from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.saving import hdf5_format
 from tensorflow.python.keras.saving import save
@@ -307,15 +304,6 @@ class Network(base_layer.Layer):
       self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
       layer._attribute_sentinel.add_parent(self._attribute_sentinel)

-    # Create the node linking internal inputs to internal outputs.
-    node_module.Node(
-        outbound_layer=self,
-        inbound_layers=[],
-        node_indices=[],
-        tensor_indices=[],
-        input_tensors=self._nested_inputs,
-        output_tensors=self._nested_outputs)
-
     # Build self.input_names and self.output_names.
     self._set_output_names()
     self.input_names = []
@@ -337,6 +325,82 @@ class Network(base_layer.Layer):
     self._compute_tensor_usage_count()
     self._set_save_spec(self._nested_inputs)

+  @property
+  def input(self):
+    """Retrieves the input tensor(s) of a layer.
+
+    Only applicable if the layer has exactly one input,
+    i.e. if it is connected to one incoming layer.
+
+    Returns:
+      Input tensor or list of input tensors.
+
+    Raises:
+      RuntimeError: if called in Eager mode.
+      AttributeError: if no inbound nodes are found.
+    """
+    if self._is_graph_network:
+      return self._nested_inputs
+    return super(Network, self).input
+
+  @property
+  def input_shape(self):
+    """Retrieves the input shape(s) of a layer.
+
+    Only applicable if the layer has exactly one input,
+    i.e. if it is connected to one incoming layer, or if all inputs
+    have the same shape.
+
+    Returns:
+      Input shape, as an integer shape tuple
+      (or list of shape tuples, one tuple per input tensor).
+
+    Raises:
+      AttributeError: if the layer has no defined input_shape.
+      RuntimeError: if called in Eager mode.
+    """
+    if self._is_graph_network:
+      return nest.map_structure(backend.int_shape, self.input)
+    return super(Network, self).input_shape
+
+  @property
+  def output(self):
+    """Retrieves the output tensor(s) of a layer.
+
+    Only applicable if the layer has exactly one output,
+    i.e. if it is connected to one incoming layer.
+
+    Returns:
+      Output tensor or list of output tensors.
+
+    Raises:
+      AttributeError: if the layer is connected to more than one incoming
+        layer.
+      RuntimeError: if called in Eager mode.
+ """ + if self._is_graph_network: + return self._nested_outputs + return super(Network, self).output + + @property + def output_shape(self): + """Retrieves the output shape(s) of a layer. + + Only applicable if the layer has one output, + or if all outputs have the same shape. + + Returns: + Output shape, as an integer shape tuple + (or list of shape tuples, one tuple per output tensor). + + Raises: + AttributeError: if the layer has no defined output shape. + RuntimeError: if called in Eager mode. + """ + if self._is_graph_network: + return nest.map_structure(backend.int_shape, self.output) + return super(Network, self).output_shape + def _set_output_names(self): """Assigns unique names to the Network's outputs. @@ -700,8 +764,7 @@ class Network(base_layer.Layer): ' implement a `call` method.') return self._run_internal_graph( - inputs, training=training, mask=mask, - convert_kwargs_to_constants=base_layer_utils.call_context().saving) + inputs, training=training, mask=mask) def compute_output_shape(self, input_shape): if not self._is_graph_network: @@ -741,20 +804,20 @@ class Network(base_layer.Layer): for depth in depth_keys: nodes = self._nodes_by_depth[depth] for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer + layer = node.layer if layer in self._input_layers: # We've already covered the input layers # a few lines above. continue - # Potentially redundant list, - # same size as node.input_tensors. + # Get the input shapes for the first argument of the node layer_input_shapes = [] - for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound(): - input_layer_key = inbound_layer.name + '_%s_%s' % (node_id, - tensor_id) + layer_inputs = node.call_args[0] + for layer_input in nest.flatten(layer_inputs): + kh = layer_input._keras_history + input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index, + kh.tensor_index) layer_input_shapes.append(layers_to_output_shapes[input_layer_key]) - layer_input_shapes = nest.pack_sequence_as(node.inbound_layers, + layer_input_shapes = nest.pack_sequence_as(layer_inputs, layer_input_shapes) # Layers expect shapes to be tuples for `compute_output_shape`. layer_input_shapes = tf_utils.convert_shapes( @@ -782,8 +845,7 @@ class Network(base_layer.Layer): # Return shapes as TensorShapes. return output_shapes - def _run_internal_graph(self, inputs, training=None, mask=None, - convert_kwargs_to_constants=False): + def _run_internal_graph(self, inputs, training=None, mask=None): """Computes output tensors for new inputs. # Note: @@ -793,21 +855,10 @@ class Network(base_layer.Layer): inputs: Tensor or nested structure of Tensors. training: Boolean learning phase. mask: (Optional) Tensor or nested structure of Tensors. - convert_kwargs_to_constants: Whether to convert Tensor kwargs to - constants. This is used when tracing the model call function during - saving to ensure that external tensors aren't captured. Returns: Two lists: output_tensors, output_masks """ - # Note: masking support is relevant mainly for Keras. - # It cannot be factored out without having the fully reimplement the network - # calling logic on the Keras side. We choose to incorporate it in - # Network because 1) it may be useful to fully support in tf.layers in - # the future and 2) Keras is a major user of Network. If you don't - # use masking, it does not interfere with regular behavior at all and you - # can ignore it. 
- inputs = self._flatten_to_reference_inputs(inputs) if mask is None: masks = [None for _ in range(len(inputs))] @@ -825,58 +876,26 @@ class Network(base_layer.Layer): depth_keys = list(self._nodes_by_depth.keys()) depth_keys.sort(reverse=True) - # Ignore the InputLayers when computing the graph. - depth_keys = depth_keys[1:] for depth in depth_keys: nodes = self._nodes_by_depth[depth] for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer + if node.is_input: + continue # Input tensors already exist. - if all( + if not all( str(id(tensor)) in tensor_dict - for tensor in nest.flatten(node.input_tensors)): + for tensor in nest.flatten(node.keras_inputs)): + continue # Node is not computable, try skipping. - # Call layer (reapplying ops to new inputs). - computed_tensors = nest.map_structure( - lambda t: tensor_dict[str(id(t))].pop(), node.input_tensors) + layer = node.layer + args, kwargs = node.map_arguments(tensor_dict) + outputs = layer(*args, **kwargs) - # Ensure `training` arg propagation if applicable. - kwargs = copy.copy(node.arguments) if node.arguments else {} - if convert_kwargs_to_constants: - kwargs = _map_tensors_to_constants(kwargs) - - argspec = self._layer_call_argspecs[layer].args - if 'training' in argspec: - if 'training' not in kwargs or kwargs['training'] is None: - kwargs['training'] = training - elif (type(kwargs['training']) is ops.Tensor and # pylint: disable=unidiomatic-typecheck - any([ - kwargs['training'] is x - for x in backend._GRAPH_LEARNING_PHASES.values() - ])): - kwargs['training'] = training # Materialize placeholder. - - # Map Keras tensors in kwargs to their computed value. - def _map_tensor_if_from_keras_layer(t): - if (isinstance(t, - (ops.Tensor, composite_tensor.CompositeTensor)) and - hasattr(t, '_keras_history')): - t_id = str(id(t)) - return tensor_dict[t_id].pop() - return t - - kwargs = nest.map_structure(_map_tensor_if_from_keras_layer, kwargs) - - # Compute outputs. - output_tensors = layer(computed_tensors, **kwargs) - - # Update tensor_dict. - for x, y in zip( - nest.flatten(node.output_tensors), nest.flatten(output_tensors)): - x_id = str(id(x)) - tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] + # Update tensor_dict. + for x, y in zip(nest.flatten(node.outputs), nest.flatten(outputs)): + x_id = str(id(x)) + tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] output_tensors = [] output_shapes = [] @@ -1367,7 +1386,7 @@ class Network(base_layer.Layer): # pylint: disable=protected-access layer = x._keras_history.layer if len(layer._inbound_nodes) > 1 or ( - layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): + layer._inbound_nodes and not layer._inbound_nodes[0].is_input): cls_name = self.__class__.__name__ logging.warning(cls_name + ' inputs must come from ' '`tf.keras.Input` (thus holding past layer metadata), ' @@ -1437,7 +1456,7 @@ class Network(base_layer.Layer): def _get_min_depth(node): """Gets the minimum depth at which node can be computed.""" min_depth = 0 - for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True): + for layer, node_id, _, _ in node.iterate_inbound(): inbound_node = layer._inbound_nodes[node_id] if inbound_node in node_to_depth: min_depth = min(min_depth, node_to_depth[inbound_node]) @@ -1465,8 +1484,8 @@ class Network(base_layer.Layer): if depth is None: # Defer until inbound nodes are processed. 
unprocessed_nodes.append(node) continue - node_key = _make_node_key(node.outbound_layer.name, - node.outbound_layer._inbound_nodes.index(node)) + node_key = _make_node_key(node.layer.name, + node.layer._inbound_nodes.index(node)) if node_key not in self._network_nodes: node_to_depth[node] = depth self._network_nodes.add(node_key) @@ -1506,21 +1525,13 @@ class Network(base_layer.Layer): for depth in depth_keys: for node in self._nodes_by_depth[depth]: input_tensors = { - str(id(tensor)) for tensor in nest.flatten(node.input_tensors) + str(id(tensor)) for tensor in nest.flatten(node.keras_inputs) } if input_tensors.issubset(available_tensors): - kwargs = copy.copy(node.arguments) if node.arguments else {} - - for tensor in nest.flatten(kwargs): - if (isinstance(tensor, - (ops.Tensor, composite_tensor.CompositeTensor)) and - hasattr(tensor, '_keras_history')): - tensor_usage_count[str(id(tensor))] += 1 - - for tensor in nest.flatten(node.input_tensors): + for tensor in nest.flatten(node.keras_inputs): tensor_usage_count[str(id(tensor))] += 1 - for output_tensor in nest.flatten(node.output_tensors): + for output_tensor in nest.flatten(node.outputs): available_tensors.add(str(id(output_tensor))) for tensor in self.outputs: @@ -1631,97 +1642,35 @@ def _map_graph_network(inputs, outputs): Raises: ValueError: In case the network is not valid (e.g. disconnected graph). """ - # Network_nodes: set of nodes included in the graph of layers - # (not all nodes included in the layers are relevant to the current graph). - network_nodes = set() # ids of all nodes relevant to the Network + # "depth" is number of layers between output Node and the Node. + # Nodes are ordered from inputs -> outputs. + nodes_in_decreasing_depth, layer_indices = _build_map(outputs) + network_nodes = { + _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node)) + for node in nodes_in_decreasing_depth + } + nodes_depths = {} # dict {node: depth value} layers_depths = {} # dict {layer: depth value} - layer_indices = {} # dict {layer: index in traversal} - nodes_in_decreasing_depth = [] - - def build_map(tensor, - finished_nodes, - nodes_in_progress, - layer, - node_index, - tensor_index): - """Builds a map of the graph of layers. - - This recursively updates the map `layer_indices`, - the list `nodes_in_decreasing_depth` and the set `network_nodes`. - - Arguments: - tensor: Some tensor in a graph. - finished_nodes: Set of nodes whose subgraphs have been traversed - completely. Useful to prevent duplicated work. - nodes_in_progress: Set of nodes that are currently active on the - recursion stack. Useful to detect cycles. - layer: Layer from which `tensor` comes from. If not provided, - will be obtained from `tensor._keras_history`. - node_index: Node index from which `tensor` comes from. - tensor_index: Tensor_index from which `tensor` comes from. - - Raises: - ValueError: if a cycle is detected. - """ - node = layer._inbound_nodes[node_index] # pylint: disable=protected-access - - # Prevent cycles. - if node in nodes_in_progress: - raise ValueError('The tensor ' + str(tensor) + ' at layer "' + - layer.name + '" is part of a cycle.') - - # Don't repeat work for shared subgraphs - if node in finished_nodes: - return - - node_key = _make_node_key(layer.name, node_index) - # Update network_nodes. - network_nodes.add(node_key) - - # Store the traversal order for layer sorting. 
- if layer not in layer_indices: - layer_indices[layer] = len(layer_indices) - - nodes_in_progress.add(node) - - # Propagate to all previous tensors connected to this node. - for layer, node_index, tensor_index, tensor in node.iterate_inbound( - include_arguments=True): - build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, - tensor_index) - - finished_nodes.add(node) - nodes_in_progress.remove(node) - nodes_in_decreasing_depth.append(node) - - finished_nodes = set() - nodes_in_progress = set() - for x in outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - build_map(x, finished_nodes, nodes_in_progress, - layer=layer, - node_index=node_index, - tensor_index=tensor_index) for node in reversed(nodes_in_decreasing_depth): # If the depth is not set, the node has no outbound nodes (depth 0). depth = nodes_depths.setdefault(node, 0) # Update the depth of the corresponding layer - previous_depth = layers_depths.get(node.outbound_layer, 0) + previous_depth = layers_depths.get(node.layer, 0) # If we've seen this layer before at a higher depth, # we should use that depth instead of the node depth. # This is necessary for shared layers that have inputs at different # depth levels in the graph. depth = max(depth, previous_depth) - layers_depths[node.outbound_layer] = depth + layers_depths[node.layer] = depth nodes_depths[node] = depth # Update the depth of inbound nodes. # The "depth" of a node is the max of the depths # of all nodes it is connected to + 1. - for node_dep in node._get_all_node_dependencies(): + for node_dep in node.parent_nodes: previous_depth = nodes_depths.get(node_dep, 0) nodes_depths[node_dep] = max(depth + 1, previous_depth) @@ -1773,9 +1722,9 @@ def _map_graph_network(inputs, outputs): layers_with_complete_input = [] # To provide a better error msg. for depth in depth_keys: for node in nodes_by_depth[depth]: - layer = node.outbound_layer - if layer: - for x in nest.flatten(node.input_tensors): + layer = node.layer + if layer and not node.is_input: + for x in nest.flatten(node.keras_inputs): if id(x) not in computable_tensors: raise ValueError('Graph disconnected: ' 'cannot obtain value for tensor ' + str(x) + @@ -1783,7 +1732,7 @@ def _map_graph_network(inputs, outputs): 'The following previous layers ' 'were accessed without issue: ' + str(layers_with_complete_input)) - for x in nest.flatten(node.output_tensors): + for x in nest.flatten(node.outputs): computable_tensors.add(id(x)) layers_with_complete_input.append(layer.name) @@ -1798,6 +1747,68 @@ def _map_graph_network(inputs, outputs): return network_nodes, nodes_by_depth, layers, layers_by_depth +def _build_map(outputs): + """This method topologically sorts nodes in order from inputs to outputs. + + It uses a depth-first search to topologically sort nodes that appear in the + _keras_history connectivity metadata of `outputs`. + + Args: + outputs: the output tensors whose _keras_history metadata should be walked. + This may be an arbitrary nested structure. + + Returns: + A tuple like (ordered_nodes, layer_to_first_traversal_index) + ordered_nodes: list of nodes appearing in the keras history, topologically + sorted from original inputs to the `outputs`. + (If outputs have different sets of ancestors, the inputs to one output + may appear after a different output). + layer_to_first_traversal_index: + A dict mapping layer to the traversal index in the DFS where it is + seen. 
Note: if a layer is shared by several nodes, the dict will only store the index corresponding to the *first* time the layer is seen.
+  """
+  finished_nodes = set()
+  nodes_in_progress = set()
+  nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
+  layer_indices = {}  # layer -> index in traversal order.
+  for output in nest.flatten(outputs):
+    _build_map_helper(output, finished_nodes, nodes_in_progress,
+                      nodes_in_decreasing_depth, layer_indices)
+  return nodes_in_decreasing_depth, layer_indices
+
+
+def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
+                      nodes_in_decreasing_depth, layer_indices):
+  """Recursive helper for `_build_map`."""
+  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
+  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access
+
+  # Don't repeat work for shared subgraphs.
+  if node in finished_nodes:
+    return
+
+  # Prevent cycles.
+  if node in nodes_in_progress:
+    raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name +
+                     '" is part of a cycle.')
+
+  # Store the traversal order for layer sorting.
+  if layer not in layer_indices:
+    layer_indices[layer] = len(layer_indices)
+
+  # Propagate to all previous tensors connected to this node.
+  nodes_in_progress.add(node)
+  if not node.is_input:
+    for tensor in node.keras_inputs:
+      _build_map_helper(tensor, finished_nodes, nodes_in_progress,
+                        nodes_in_decreasing_depth, layer_indices)
+
+  finished_nodes.add(node)
+  nodes_in_progress.remove(node)
+  nodes_in_decreasing_depth.append(node)
+
+
 def _map_subgraph_network(inputs, outputs):
   """Returns the nodes and layers in the topology from `inputs` to `outputs`.

@@ -1820,36 +1831,6 @@ def _should_skip_first_node(layer):
   return issubclass(layer.__class__, Network) and layer._is_graph_network


-def _serialize_tensors(kwargs):
-  """Serializes Tensors passed to `call`."""
-
-  def _serialize_keras_tensor(t):
-    """Serializes a single Tensor passed to `call`."""
-    if hasattr(t, '_keras_history'):
-      kh = t._keras_history
-      return [kh.layer.name, kh.node_index, kh.tensor_index]
-
-    if isinstance(t, np.ndarray):
-      return t.tolist()
-
-    if isinstance(t, ops.Tensor):
-      return backend.get_value(t).tolist()
-
-    return t
-
-  return nest.map_structure(_serialize_keras_tensor, kwargs)
-
-
-def _map_tensors_to_constants(kwargs):
-
-  def _map_to_constants(t):
-    if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
-      return constant_op.constant(backend.get_value(t))
-    return t
-
-  return nest.map_structure(_map_to_constants, kwargs)
-
-
 def _deserialize_keras_tensors(kwargs, layer_map):
   """Deserializes Keras Tensors passed to `call`."""

@@ -1863,7 +1844,7 @@ def _deserialize_keras_tensors(kwargs, layer_map):
       layer = layer_map[layer_name]

       node = layer._inbound_nodes[node_index]
-      return nest.flatten(node.output_tensors)[tensor_index]
+      return nest.flatten(node.outputs)[tensor_index]
     return t

   kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
@@ -1961,7 +1942,7 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
         return
       inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
       input_tensors.append(
-          nest.flatten(inbound_node.output_tensors)[inbound_tensor_index])
+          nest.flatten(inbound_node.outputs)[inbound_tensor_index])
     input_tensors = nest.pack_sequence_as(node_data, input_tensors)
     # Call layer on its inputs, thus creating the node
     # and building the layer if needed.
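[Editor's note] The `_build_map` / `_build_map_helper` pair above is a standard depth-first, post-order topological sort with cycle detection, walking `_keras_history` metadata from the outputs back to the inputs. A minimal standalone sketch of the same pattern (the `SimpleNode` class and all names here are illustrative stand-ins, not part of this patch):

class SimpleNode(object):

  def __init__(self, name, parents=()):
    self.name = name
    self.parents = list(parents)  # plays the role of `node.parent_nodes`


def topo_sort(output_nodes):
  """Orders nodes from inputs to outputs; raises on cycles."""
  finished = set()
  in_progress = set()
  ordered = []  # post-order: parents are appended before children.

  def visit(node):
    if node in finished:
      return  # shared subgraphs are traversed only once
    if node in in_progress:
      raise ValueError('node %s is part of a cycle' % node.name)
    in_progress.add(node)
    for parent in node.parents:
      visit(parent)
    in_progress.remove(node)
    finished.add(node)
    ordered.append(node)

  for node in output_nodes:
    visit(node)
  return ordered


a = SimpleNode('a')
b = SimpleNode('b')
c = SimpleNode('c', parents=[a, b])
assert [n.name for n in topo_sort([c])] == ['a', 'b', 'c']

Because the append happens after the recursive calls return, the resulting order is exactly the `nodes_in_decreasing_depth` list that `_map_graph_network` consumes.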
@@ -2077,37 +2058,11 @@ def get_network_config(network, serialize_layer_fn=None): filtered_inbound_nodes = [] for original_node_index, node in enumerate(layer._inbound_nodes): node_key = _make_node_key(layer.name, original_node_index) - if node_key in network._network_nodes: + if node_key in network._network_nodes and not node.is_input: # The node is relevant to the model: # add to filtered_inbound_nodes. - if node.arguments: - kwargs = _serialize_tensors(node.arguments) - try: - json.dumps(kwargs) - except TypeError: - logging.warning( - 'Layer ' + layer.name + - ' was passed non-serializable keyword arguments: ' + - str(node.arguments) + '. They will not be included ' - 'in the serialized model (and thus will be missing ' - 'at deserialization time).') - kwargs = {} - else: - kwargs = {} - if node.inbound_layers: - node_data = [] - for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound(): - node_key = _make_node_key(inbound_layer.name, node_id) - new_node_index = node_conversion_map.get(node_key, 0) - node_data.append( - tf_utils.ListWrapper( - [inbound_layer.name, new_node_index, tensor_id, kwargs])) - node_data = nest.pack_sequence_as(node.input_tensors, node_data) - if not nest.is_sequence(node_data): - node_data = [node_data] - # Convert ListWrapper to list for backwards compatible configs. - node_data = tf_utils.convert_inner_node_data(node_data) - filtered_inbound_nodes.append(node_data) + node_data = node.serialize(_make_node_key, node_conversion_map) + filtered_inbound_nodes.append(node_data) layer_config = serialize_layer_fn(layer) layer_config['name'] = layer.name diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index 429a096037f..7c19a3ae2bd 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -1785,6 +1785,7 @@ class AttrTrackingLayer(base_layer.Layer): return super(AttrTrackingLayer, self).dynamic +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) class CacheCorrectnessTest(keras_parameterized.TestCase): def layer_and_network_test(self): @@ -1919,8 +1920,12 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): class MyLayer(base_layer.Layer): def call(self, x, training=None): - self.training = training - return x + if training is None: + return x * -1.0 + elif training: + return x + else: + return x * 0.0 my_layer = MyLayer() x = np.ones((1, 10)) @@ -1929,9 +1934,8 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): outputs = my_layer(inputs, training=True) network = network_lib.Network(inputs, outputs) - network(x, training=False) # Hard-coded value passed during construction is respected. - self.assertTrue(my_layer.training) + self.assertAllEqual(network(x, training=False), x) inputs = input_layer_lib.Input(10) outputs = my_layer(inputs, training=False) @@ -1939,19 +1943,16 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): network(x, training=True) # Hard-coded value passed during construction is respected. - self.assertFalse(my_layer.training) + self.assertAllEqual(network(x, training=True), x * 0.0) inputs = input_layer_lib.Input(10) outputs = my_layer(inputs, training=None) network = network_lib.Network(inputs, outputs) - network(x, training=True) # `None` value passed during construction is overridden. - self.assertTrue(my_layer.training) - network(x, training=False) + self.assertAllEqual(network(x, training=True), x) # `None` value passed during construction is overridden. 
- self.assertFalse(my_layer.training) - + self.assertAllEqual(network(x, training=False), x * 0.0) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py index c8255764efc..945cf1c64bd 100644 --- a/tensorflow/python/keras/engine/node.py +++ b/tensorflow/python/keras/engine/node.py @@ -18,10 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import copy +import json +import numpy as np + from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import serialization class Node(object): @@ -33,163 +41,227 @@ class Node(object): a node is added to `layer._outbound_nodes`. Arguments: - outbound_layer: the layer that takes - `input_tensors` and turns them into `output_tensors` - (the node gets created when the `call` - method of the layer was called). - inbound_layers: a list of layers, the same length as `input_tensors`, - the layers from where `input_tensors` originate. - node_indices: a list of integers, the same length as `inbound_layers`. - `node_indices[i]` is the origin node of `input_tensors[i]` - (necessary since each inbound layer might have several nodes, - e.g. if the layer is being shared with a different data stream). - tensor_indices: a list of integers, - the same length as `inbound_layers`. - `tensor_indices[i]` is the index of `input_tensors[i]` within the - output of the inbound layer - (necessary since each inbound layer might - have multiple tensor outputs, with each one being - independently manipulable). - input_tensors: list of input tensors. - output_tensors: list of output tensors. - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. - - `node_indices` and `tensor_indices` are basically fine-grained coordinates - describing the origin of the `input_tensors`. - - A node from layer A to layer B is added to: - - A._outbound_nodes - - B._inbound_nodes + layer: The Layer for the Layer.__call__ this node represents. + call_args: The positional arguments the Layer was called with. + call_kwargs: The keyword arguments the Layer was called with. + outputs: The outputs of the Layer.__call__ """ def __init__(self, - outbound_layer, - inbound_layers, - node_indices, - tensor_indices, - input_tensors, - output_tensors, - arguments=None): - # Layer instance (NOT a sequence) - if isinstance(outbound_layer, (list, tuple, dict)): - raise ValueError('`outbound_layer` should be a layer instance, ' - 'not a list, tuple, or, dict.') + layer, + call_args=None, + call_kwargs=None, + outputs=None): + call_args = [] if call_args is None else call_args + call_kwargs = {} if call_kwargs is None else call_kwargs + outputs = [] if outputs is None else outputs - # These arguments are user-provided. Copy them here so that future - # user modifications do not affect the node's metadata. 
- input_tensors = nest.map_structure(lambda t: t, input_tensors) - output_tensors = nest.map_structure(lambda t: t, output_tensors) - arguments = nest.map_structure(lambda t: t, arguments) + self.layer = layer + self.is_input = not call_args and not call_kwargs - # this is the layer that takes a nested structure of input tensors - # and turns them into a nested structure of output tensors. - # the current node will be added to - # the inbound_nodes of outbound_layer. - self.outbound_layer = outbound_layer + # These arguments are user-provided. Copy the structures here so that + # future user modifications do not affect the node's metadata. + # We copy using map_structure rather than python's shallow or deep copy, + # because the args can be data structures (so shallow copy is + # insufficient), but individual values might not support copy.copy + # or be too expensive to deep copy. + call_args = nest.map_structure(lambda t: t, call_args) + call_kwargs = nest.map_structure(lambda t: t, call_kwargs) + self.outputs = nest.map_structure(lambda t: t, outputs) + self.call_args = call_args + self.call_kwargs = call_kwargs - # The following 3 properties describe where - # the input tensors come from: which layers, - # and for each layer, which node and which - # tensor output of each node. + # Cached for performance. + self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs)) - # Nested structure of layer instances. - self.inbound_layers = inbound_layers - # Nested structure of integers, 1:1 mapping with inbound_layers. - self.node_indices = node_indices - # Nested of integers, 1:1 mapping with inbound_layers. - self.tensor_indices = tensor_indices + # Create TensorFlowOpLayers if needed. + for obj in self._flat_arguments: + if (isinstance(obj, ops.Tensor) and + base_layer_utils.needs_keras_history(obj, ignore_call_context=True)): + base_layer_utils.create_keras_history(obj) - # Following 2 properties: - # tensor inputs and outputs of outbound_layer. + self._keras_inputs = [] + self._keras_inputs_ids_and_indices = [] + for i, ele in enumerate(self._flat_arguments): + if is_keras_tensor(ele): + self._keras_inputs.append(ele) + kt_id = str(id(ele)) + kt_index = i + self._keras_inputs_ids_and_indices.append((kt_id, kt_index)) - # Nested structure of tensors. 1:1 mapping with inbound_layers. - self.input_tensors = input_tensors - # Nested structure of tensors, created by outbound_layer.call(). - self.output_tensors = output_tensors + # Wire up Node to Layers. + self.layer._inbound_nodes.append(self) + for kt in self.keras_inputs: + inbound_layer = kt._keras_history.layer + if inbound_layer is not None: # `None` for `Input` tensors. + inbound_layer._outbound_nodes.append(self) - # Following 2 properties: input and output shapes. + # Set metadata on outputs. + node_index = len(self.layer._inbound_nodes) - 1 + for i, tensor in enumerate(nest.flatten(outputs)): + tensor._keras_history = KerasHistory( + layer=layer, node_index=node_index, tensor_index=i) - # Nested structure of shape tuples, shapes of input_tensors. - self.input_shapes = nest.map_structure(backend.int_shape, input_tensors) - # Nested structure of shape tuples, shapes of output_tensors. - self.output_shapes = nest.map_structure(backend.int_shape, output_tensors) + @property + def keras_inputs(self): + """Tensors input to this node that can be traced back to a `keras.Input`.""" + return self._keras_inputs - # Optional keyword arguments to layer's `call`. 
- self.arguments = arguments - - # Create Keras History for any Keras Tensors in `arguments`. - tensor_arguments = [ - t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor) - ] - for tensor_argument in tensor_arguments: - if base_layer_utils.needs_keras_history( - tensor_argument, ignore_call_context=True): - base_layer_utils.create_keras_history(tensor_argument) - - # Add nodes to all layers involved. - for layer in nest.flatten(inbound_layers): - if layer is not None: - # For compatibility with external Keras, we use the deprecated - # accessor here. - layer.outbound_nodes.append(self) - # For compatibility with external Keras, we use the deprecated - # accessor here. - outbound_layer.inbound_nodes.append(self) - - def iterate_inbound(self, include_arguments=False): - """Returns a list of tuples representing the inbound data. - - Arguments: - include_arguments: Whether to also iterate over any Keras Tensors - passed as args, kwargs. - - Returns: - List of tuples like: (inbound_layer, node_index, tensor_index, tensor). - """ - inputs_inbound = list( - zip( - nest.flatten(self.inbound_layers), - nest.flatten(self.node_indices), - nest.flatten(self.tensor_indices), - nest.flatten(self.input_tensors))) - - if include_arguments: - keras_tensor_arguments = [ - kt for kt in nest.flatten(self.arguments) - if hasattr(kt, '_keras_history') - ] - - def _get_inbound(keras_tensor): - kh = keras_tensor._keras_history - return kh.layer, kh.node_index, kh.tensor_index, keras_tensor - - arguments_inbound = nest.map_structure(_get_inbound, - keras_tensor_arguments) - - return inputs_inbound + arguments_inbound - else: - return inputs_inbound - - def _get_all_node_dependencies(self): - """Returns all of the nodes this node immediately depends on.""" + @property + def parent_nodes(self): + """Returns all the `Node`s whose output this node immediately depends on.""" node_deps = [] - for layer, node_index, _, _ in self.iterate_inbound(): - node_deps.append(layer._inbound_nodes[node_index]) - - for arg in nest.flatten(self.arguments): - if isinstance(arg, ops.Tensor) and hasattr(arg, '_keras_history'): - kh = arg._keras_history - node_deps.append(kh.layer._inbound_nodes[kh.node_index]) - + for kt in self.keras_inputs: + layer = kt._keras_history.layer + node_index = kt._keras_history.node_index + if layer is not None: # `None` for `Input` tensors. + node_deps.append(layer._inbound_nodes[node_index]) return node_deps - def get_config(self): - inbound_names = nest.map_structure( - lambda layer: layer.name if layer else None, self.inbound_layers) - return { - 'outbound_layer': self.outbound_layer.name, - 'inbound_layers': inbound_names, - 'node_indices': self.node_indices, - 'tensor_indices': self.tensor_indices - } + def iterate_inbound(self): + """Yields tuples representing the data inbound from other nodes. + + Yields: + tuples like: (inbound_layer, node_index, tensor_index, tensor). 
+ """ + for kt in self.keras_inputs: + keras_history = kt._keras_history + layer = keras_history.layer + node_index = keras_history.node_index + tensor_index = keras_history.tensor_index + yield layer, node_index, tensor_index, kt + + def map_arguments(self, tensor_dict): + """Maps Keras Tensors to computed Tensors using `tensor_dict`.""" + flat_arguments = copy.copy(self._flat_arguments) + for kt_id, kt_index in self._keras_inputs_ids_and_indices: + flat_arguments[kt_index] = tensor_dict[kt_id].pop() + + args, kwargs = nest.pack_sequence_as( + (self.call_args, self.call_kwargs), flat_arguments) + return args, kwargs + + def serialize(self, make_node_key, node_conversion_map): + """Serializes `Node` for Functional API's `get_config`.""" + # Serialization still special-cases first argument. + args, kwargs = self.call_args, self.call_kwargs + inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs) + + # Treat everything other than first argument as a kwarg. + arguments = dict(zip(self.layer._call_fn_args[1:], args)) + arguments.update(kwargs) + kwargs = arguments + + kwargs = nest.map_structure(_serialize_keras_tensor, kwargs) + try: + json.dumps(kwargs, default=serialization.get_json_type) + except TypeError: + kwarg_types = nest.map_structure(type, kwargs) + logging.warning('Layer ' + self.layer.name + + ' was passed non-JSON-serializable arguments. ' + + 'Arguments had types: ' + + str(kwarg_types) + '. They will not be included ' + 'in the serialized model (and thus will be missing ' + 'at deserialization time).') + kwargs = {} + + # `kwargs` is added to each Tensor in the first arg. This should be + # changed in a future version of the serialization format. + def serialize_first_arg_tensor(t): + kh = t._keras_history + node_index = kh.node_index + node_key = make_node_key(kh.layer.name, node_index) + new_node_index = node_conversion_map.get(node_key, 0) + data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs] + return tf_utils.ListWrapper(data) + + data = nest.map_structure(serialize_first_arg_tensor, inputs) + if not nest.is_sequence(data): + data = [data] + data = tf_utils.convert_inner_node_data(data) + return data + + ############################################################# + # Properties for Backwards compatibility. + # These only check the first input argument + # As nodes are internal, they may be removed in the future. + ############################################################# + + @property + def input_tensors(self): + if self.is_input: + return [self.outputs] # Used in `Layer.input`. + return self.call_args[0] + + @property + def output_tensors(self): + if self.is_input: + return [self.outputs] # Used in `Layer.input`. + return self.outputs + + @property + def input_shapes(self): + input_shapes = nest.map_structure(backend.int_shape, self.input_tensors) + if len(input_shapes) == 1 and not self.is_input: + return input_shapes[0] + return input_shapes + + @property + def output_shapes(self): + return nest.map_structure(backend.int_shape, self.output_tensors) + + @property + def outbound_layer(self): + return self.layer + + @property + def inbound_layers(self): + if self.is_input: + return [] + inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, + self.call_args[0]) + return inbound_layers + + +class KerasHistory( + collections.namedtuple('KerasHistory', + ['layer', 'node_index', 'tensor_index'])): + """Tracks the Layer call that created a Tensor, for Keras Graph Networks. 
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 9edfa4f958b..2d5abac7fd6 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -215,7 +215,7 @@ class Sequential(training.Model):
         set_inputs = True
 
       if set_inputs:
-        outputs = nest.flatten(layer._inbound_nodes[-1].output_tensors)
+        outputs = nest.flatten(layer._inbound_nodes[-1].outputs)
         if len(outputs) != 1:
           raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
         self.outputs = outputs
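A sketch of the rename this hunk relies on (illustrative only, using private
APIs): `Node.outputs` is now the raw structure returned by the layer call,
while `output_tensors` survives only as a backwards-compatibility property:

    import tensorflow as tf

    layer = tf.keras.layers.Dense(2)
    y = layer(tf.keras.Input(shape=(3,)))
    node = layer._inbound_nodes[-1]
    assert node.outputs is y         # raw output structure of the call
    assert node.output_tensors is y  # compat alias for non-input nodes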
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index 0cbb70109fc..a93bba0271b 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import contextlib
-
 from absl.testing import parameterized
 import numpy as np
 
@@ -302,11 +300,11 @@ class CorrectnessTest(keras_parameterized.TestCase):
     self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)
 
   @parameterized.named_parameters([
-      ('_None', contextlib.contextmanager(lambda: iter([None])), 0., 4.),
-      ('_0', lambda: keras.backend.learning_phase_scope(0), 4., 4.),
-      ('_1', lambda: keras.backend.learning_phase_scope(1), 0., 0.),
+      ('_None', None, 0., 4.),
+      ('_False', False, 4., 4.),
+      ('_True', True, 0., 0.),
   ])
-  def test_nested_model_learning_phase(self, nested_scope_fn,
+  def test_nested_model_learning_phase(self, training,
                                        expected_training_loss,
                                        expected_validation_loss):
     """Tests that learning phase is correctly set in an intermediate layer."""
@@ -326,18 +324,17 @@ class CorrectnessTest(keras_parameterized.TestCase):
       return keras.Model(inputs, outputs)
 
     def _regularize_model(unregularized_model):
-      inputs = keras.Input(unregularized_model.inputs[0].shape[1:])
-      with nested_scope_fn():
-        logits = unregularized_model(inputs)
-      outputs = keras.activations.softmax(logits)
-      model = keras.Model(inputs, outputs)
       # Regularize the most recent activations of a post-dropout layer.
       sample_activations = unregularized_model.get_layer(
          index=-2).get_output_at(-1)
       regularization_loss = keras.backend.mean(sample_activations)
-      model.add_loss(regularization_loss)
-      model.add_metric(
+      unregularized_model.add_loss(regularization_loss)
+      unregularized_model.add_metric(
          regularization_loss, aggregation='mean', name='regularization_loss')
+      inputs = keras.Input(unregularized_model.inputs[0].shape[1:])
+      logits = unregularized_model(inputs, training=training)
+      outputs = keras.activations.softmax(logits)
+      model = keras.Model(inputs, outputs)
       return model
 
     # Make and compile models.
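The rewritten test exercises the now-preferred pattern of passing `training`
explicitly to a nested model instead of wrapping the call in
`learning_phase_scope`; a minimal sketch (not part of the patch):

    import tensorflow as tf

    inner = tf.keras.Sequential([tf.keras.layers.Dropout(0.5)])
    x = tf.keras.Input(shape=(4,))
    # Dropout stays active in the outer model even at inference time.
    y = inner(x, training=True)
    outer = tf.keras.Model(x, y)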
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 0b0121f521e..eaffb90e64b 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -112,11 +112,12 @@ def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
       # then call node.inbound_layer on them.
       if all(
           tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
-        computed_tensors = nest.map_structure(lambda t: tensor_map[t],
-                                              node.input_tensors)
         # Call layer.
-        kwargs = node.arguments or {}
-        output_tensors = layer(computed_tensors, **kwargs)
+        args = nest.map_structure(lambda t: tensor_map.get(t, t),
+                                  node.call_args)
+        kwargs = nest.map_structure(lambda t: tensor_map.get(t, t),
+                                    node.call_kwargs)
+        output_tensors = layer(*args, **kwargs)
 
         # Thread-safe way to keep track of what node was created.
         first_output_tensor = nest.flatten(output_tensors)[0]
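The `tensor_map.get(t, t)` idiom above swaps only the tensors being cloned and
passes every other call argument (constants, numpy arrays, etc.) through
untouched; a standalone sketch with string stand-ins for tensors:

    from tensorflow.python.util import nest

    tensor_map = {'kt_a': 'clone_a'}
    call_args = (['kt_a', 3.0], {'scale': 2})
    mapped = nest.map_structure(lambda t: tensor_map.get(t, t), call_args)
    # -> (['clone_a', 3.0], {'scale': 2})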
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index f56d55b18d5..6c2037c4f4b 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -547,7 +547,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
       else:
         return inputs
 
-    t = array_ops.sequence_mask(1)
+    t = self.evaluate(array_ops.sequence_mask(1))
     inputs = keras.layers.Input(shape=(3))
     model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t))
 
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 1dfd2f517c6..d2d3d919fff 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -55,7 +55,7 @@ def get_source_inputs(tensor, layer=None, node_index=None):
     return [tensor]
   else:
     node = layer._inbound_nodes[node_index]
-    if not node.inbound_layers:
+    if node.is_input:
       # Reached an Input layer, stop recursion.
       return nest.flatten(node.input_tensors)
     else:
@@ -140,7 +140,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
     nodes = []
     for v in nodes_by_depth:
       if (len(v) > 1) or (len(v) == 1 and
-                          len(nest.flatten(v[0].inbound_layers)) > 1):
+                          len(nest.flatten(v[0].keras_inputs)) > 1):
         # if the model has multiple nodes
         # or if the nodes have multiple inbound_layers
         # the model is no longer sequential
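The behaviour of `get_source_inputs` is unchanged by the switch to the
`is_input` check; a quick sketch (illustrative only, using the private
`layer_utils` module):

    import tensorflow as tf
    from tensorflow.python.keras.utils.layer_utils import get_source_inputs

    x = tf.keras.Input(shape=(3,))
    y = tf.keras.layers.Dense(2)(x)
    assert get_source_inputs(y)[0] is x  # recursion stops at the Input node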