From 9485250f44fb504396306156d3f9dfc2c09f1c29 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Sat, 25 Apr 2020 13:43:51 -0700 Subject: [PATCH] Stops special-casing first argument in Keras's internal Node representation. This makes the special-casing-behavior in __call__ more explicit, and make it easier to reduce how much special casing of the first argument Keras does moving forward. It also allows us to simplify internal node-walking code. This may come with a subtle behavior change/simplification in the interaction between `training`, learning phase scopes and model construction. Before, it seems that constructing a network model in a learning phase scope would cause that model to permanently use that learning phase in some future code regardless of what values of `training` were passed in. The code underlying this behavior seems to have been pretty fragile and hard to guarantee. It was also quite likely unintentional or buggy in many cases. After this change, setting a learning phase scope will continue affecting call-time behavior when `training` isn't explicitly passed in, but building a model inside of a learning phase scope won't force it to always use that learning phase. (Passing `training=...` at construction time will continue to be the recommended method of freezing behavior at call time) PiperOrigin-RevId: 308436442 Change-Id: I8cb8922a6e3cd219a1771328dda3003287978b39 --- tensorflow/python/keras/backend.py | 8 +- tensorflow/python/keras/engine/BUILD | 14 + tensorflow/python/keras/engine/base_layer.py | 195 ++++---- .../python/keras/engine/base_layer_utils.py | 6 +- .../python/keras/engine/base_layer_v1.py | 109 ++--- tensorflow/python/keras/engine/input_layer.py | 16 +- tensorflow/python/keras/engine/network.py | 419 ++++++++---------- .../python/keras/engine/network_test.py | 21 +- tensorflow/python/keras/engine/node.py | 362 +++++++++------ tensorflow/python/keras/engine/node_test.py | 160 +++++++ tensorflow/python/keras/engine/sequential.py | 2 +- .../keras/engine/training_eager_test.py | 23 +- tensorflow/python/keras/models.py | 9 +- .../saving/saved_model/saved_model_test.py | 2 +- tensorflow/python/keras/utils/layer_utils.py | 4 +- 15 files changed, 751 insertions(+), 599 deletions(-) create mode 100644 tensorflow/python/keras/engine/node_test.py diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index f76854706e8..c79237bb727 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -133,6 +133,7 @@ class _DummyEagerGraph(threading.local): # get a different key. super(_DummyEagerGraph, self).__init__() self.key = _DummyEagerGraph._WeakReferencableClass() + self.learning_phase_is_set = False _DUMMY_EAGER_GRAPH = _DummyEagerGraph() @@ -295,6 +296,7 @@ def clear_session(): _SESSION.session = None graph = get_graph() with graph.as_default(): + _DUMMY_EAGER_GRAPH.learning_phase_is_set = False _GRAPH_LEARNING_PHASES.clear() # Create the learning phase placeholder in graph using the default factory. _GRAPH_LEARNING_PHASES.setdefault(graph) @@ -351,7 +353,7 @@ def learning_phase(): def global_learning_phase_is_set(): - return _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES + return _DUMMY_EAGER_GRAPH.learning_phase_is_set def _mark_func_graph_as_unsaveable(graph, learning_phase): @@ -420,6 +422,7 @@ def set_learning_phase(value): if context.executing_eagerly(): # In an eager context, the learning phase values applies to both the eager # context and the internal Keras graph. + _DUMMY_EAGER_GRAPH.learning_phase_is_set = True _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value _GRAPH_LEARNING_PHASES[get_graph()] = value @@ -451,11 +454,14 @@ def learning_phase_scope(value): _DUMMY_EAGER_GRAPH.key, None) previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None) + learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set try: set_learning_phase(value) yield finally: # Restore learning phase to initial value. + if not learning_phase_previously_set: + _DUMMY_EAGER_GRAPH.learning_phase_is_set = False with ops.init_scope(): if context.executing_eagerly(): if previous_eager_value is not None: diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 5a24be11f67..b010b1f2b18 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -495,6 +495,20 @@ tf_py_test( ], ) +tf_py_test( + name = "node_test", + size = "medium", + srcs = ["node_test.py"], + python_version = "PY3", + shard_count = 3, + deps = [ + ":base_layer", + ":engine", + "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/keras/utils:layer_utils", + ], +) + tf_py_test( name = "base_layer_test", size = "medium", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 7a6ba186bc3..7d12d1635ac 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections +import copy import functools import itertools import threading @@ -797,15 +797,22 @@ class Layer(module.Module, version_utils.LayerVersionSelector): raise RuntimeError( 'You must call `super().__init__()` in the layer constructor.') - # Grab the first positional or keyword argument. - if args: - inputs = args[0] - args = args[1:] - elif self._call_fn_args[0] in kwargs: - inputs = kwargs.pop(self._call_fn_args[0]) - else: - raise ValueError( - 'The first argument to `Layer.call` must always be passed.') + # 'inputs` (the first arg in the method spec) is special cased in + # layer call due to historical reasons. + # This special casing currently takes the form of: + # - 'inputs' must be explicitly passed. A layer cannot have zero arguments, + # and inputs cannot have been provided via the default value of a kwarg. + # - numpy/scalar values in `inputs` get converted to tensors + # - implicit masks / mask metadata are only collected from 'inputs` + # - Layers are built using shape info from 'inputs' only + # - input_spec compatibility is only checked against `inputs` + # - checking if a layer has ragged tensor support is only done against + # `inputs` + # - mixed precision casting (autocast) is only applied to `inputs`, + # not to any other argument. + # - configuring the Functional API SavedModel saving spec for deciding what + # should be serialized during SavedModel saving + inputs, args, kwargs = self._split_out_first_arg(args, kwargs) call_context = base_layer_utils.call_context() input_list = nest.flatten(inputs) @@ -814,6 +821,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # This is always the case in graph mode. It can also be the case in eager # mode when all inputs can be traced back to `keras.Input()` (when building # models using the functional API). + # TODO(kaftan): make this not special case inputs. Instead + # build a functional api model if *any* *arg or **kwarg is symbolic, + # even if part of the data structure in that arg is not symbolic. build_graph = tf_utils.are_all_symbolic_tensors(input_list) # Accept NumPy and scalar inputs by converting to Tensors. @@ -838,16 +848,17 @@ class Layer(module.Module, version_utils.LayerVersionSelector): mask_arg_passed_by_framework = True kwargs['mask'] = input_masks - # If `training` argument was not explicitly passed, propagate `training` - # value from this layer's calling layer. + # If `training` argument is None or not explicitly passed, + # propagate `training` value from this layer's calling layer. + training_value = None training_arg_passed_by_framework = False # Priority 1: `training` was explicitly passed. if self._call_arg_was_passed('training', args, kwargs): training_value = self._get_call_arg_value('training', args, kwargs) if not self._expects_training_arg: kwargs.pop('training') - else: - training_value = None + + if training_value is None: # Priority 2: `training` was passed to a parent layer. if call_context.training is not None: training_value = call_context.training @@ -867,12 +878,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector): training_value = math_ops.cast(training_value, dtypes.bool) else: training_value = bool(training_value) - kwargs['training'] = training_value + args, kwargs = self._set_call_arg_value( + 'training', training_value, args, kwargs) training_arg_passed_by_framework = True # Only create Keras history if at least one tensor originates from a # `keras.Input`. Otherwise this Layer may be being used outside the Keras # framework. + # TODO(kaftan): make this not special case inputs if build_graph and base_layer_utils.needs_keras_history(inputs): base_layer_utils.create_keras_history(inputs) @@ -953,13 +966,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector): raise ValueError('A layer\'s `call` method should return a ' 'Tensor or a list of Tensors, not None ' '(layer: ' + self.name + ').') + # TODO(kaftan): This should be 'any' and check all args if base_layer_utils.have_all_keras_metadata(inputs): if training_arg_passed_by_framework: - kwargs.pop('training') + args, kwargs = self._set_call_arg_value( + 'training', None, args, kwargs, pop_kwarg_if_none=True) if mask_arg_passed_by_framework: kwargs.pop('mask') - inputs, outputs = self._set_connectivity_metadata_( - inputs, outputs, args, kwargs) + # Node connectivity does not special-case the first argument. + outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, + outputs) self._handle_activity_regularization(inputs, outputs) self._set_mask_metadata(inputs, outputs, input_masks) if hasattr(self, '_set_inputs') and not self.inputs: @@ -2299,70 +2315,45 @@ class Layer(module.Module, version_utils.LayerVersionSelector): args_dict = dict(zip(call_fn_args, args)) return args_dict[arg_name] - def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): + def _set_call_arg_value( + self, arg_name, new_value, args, + kwargs, inputs_in_args=False, pop_kwarg_if_none=False): + arg_pos = self._call_fn_arg_positions.get(arg_name, None) + if arg_pos is not None: + if not inputs_in_args: + # Ignore `inputs` arg. + arg_pos = arg_pos - 1 + if len(args) > arg_pos: + args = list(args) + args[arg_pos] = new_value + return args, kwargs + if new_value is None and pop_kwarg_if_none: + kwargs.pop(arg_name, None) + else: + kwargs[arg_name] = new_value + return args, kwargs - # If the layer returns tensors from its inputs, unmodified, - # we copy them to avoid loss of tensor metadata. - output_ls = nest.flatten(outputs) - inputs_ls = object_identity.ObjectIdentitySet(nest.flatten(inputs)) - output_ls_copy = [] - for x in output_ls: - if x in inputs_ls: + def _set_connectivity_metadata(self, args, kwargs, outputs): + # If the layer returns tensors from its inputs unmodified, + # we copy them to avoid loss of KerasHistory metadata. + flat_outputs = nest.flatten(outputs) + flat_inputs = nest.flatten((args, kwargs)) + inputs_set = object_identity.ObjectIdentitySet(flat_inputs) + outputs_copy = [] + for x in flat_outputs: + if x in inputs_set: with backend.name_scope(self.name): x = array_ops.identity(x) - output_ls_copy.append(x) - outputs = nest.pack_sequence_as(outputs, output_ls_copy) + outputs_copy.append(x) + outputs = nest.pack_sequence_as(outputs, outputs_copy) - # Ignore `inputs` arg. - arguments = dict(zip(self._call_fn_args[1:], args)) - arguments.update(kwargs) - - # Add an inbound node to the layer, so it can keep track of this call. - # This updates the layer history of the output tensor(s). - self._add_inbound_node( - input_tensors=inputs, output_tensors=outputs, arguments=arguments) - return inputs, outputs - - def _add_inbound_node(self, - input_tensors, - output_tensors, - arguments=None): - """Internal method to create an inbound node for the layer. - - Arguments: - input_tensors: list of input tensors. - output_tensors: list of output tensors. - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. - """ - inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, - input_tensors) - node_indices = nest.map_structure(lambda t: t._keras_history.node_index, - input_tensors) - tensor_indices = nest.map_structure(lambda t: t._keras_history.tensor_index, - input_tensors) - - # Create node, add it to inbound nodes. - node_module.Node( - self, - inbound_layers=inbound_layers, - node_indices=node_indices, - tensor_indices=tensor_indices, - input_tensors=input_tensors, - output_tensors=output_tensors, - arguments=arguments) - - # Update tensor history metadata. - # The metadata attribute consists of - # 1) a layer instance - # 2) a node index for the layer - # 3) a tensor index for the node. - # The allows layer reuse (multiple nodes per layer) and multi-output - # or multi-input layers (e.g. a layer can return multiple tensors, - # and each can be sent to a different layer). - for i, tensor in enumerate(nest.flatten(output_tensors)): - tensor._keras_history = KerasHistory(self, - len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access + # Create node, Node wires itself to inbound and outbound layers. + # The Node constructor actually updates this layer's self._inbound_nodes, + # sets _keras_history on the outputs, and adds itself to the + # `_outbound_nodes` of the layers that produced the inputs to this + # layer call. + node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs) + return outputs def _get_node_attribute_at_index(self, node_index, attr, attr_name): """Private utility to retrieves an attribute (e.g. inputs) from a node. @@ -2706,6 +2697,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector): return all_args[1:] return all_args + @property + @tracking.cached_per_instance + def _call_fn_arg_positions(self): + call_fn_arg_positions = dict() + for pos, arg in enumerate(self._call_fn_args): + call_fn_arg_positions[arg] = pos + return call_fn_arg_positions + @property @tracking.cached_per_instance def _call_accepts_kwargs(self): @@ -2743,6 +2742,21 @@ class Layer(module.Module, version_utils.LayerVersionSelector): seen_weights.add(w) return output + def _split_out_first_arg(self, args, kwargs): + # Grab the argument corresponding to the first argument in the + # layer's `call` method spec. This will either be the first positional + # argument, or it will be provided as a keyword argument. + if args: + inputs = args[0] + args = args[1:] + elif self._call_fn_args[0] in kwargs: + kwargs = copy.copy(kwargs) + inputs = kwargs.pop(self._call_fn_args[0]) + else: + raise ValueError( + 'The first argument to `Layer.call` must always be passed.') + return inputs, args, kwargs + # SavedModel properties. Please see keras/saving/saved_model for details. @property @@ -2952,31 +2966,6 @@ class AddMetric(Layer): return config -class KerasHistory( - collections.namedtuple('KerasHistory', - ['layer', 'node_index', 'tensor_index'])): - """Tracks the Layer call that created a Tensor, for Keras Graph Networks. - - During construction of Keras Graph Networks, this metadata is added to - each Tensor produced as the output of a Layer, starting with an - `InputLayer`. This allows Keras to track how each Tensor was produced, and - this information is later retraced by the `keras.engine.Network` class to - reconstruct the Keras Graph Network. - - Attributes: - layer: The Layer that produced the Tensor. - node_index: The specific call to the Layer that produced this Tensor. Layers - can be called multiple times in order to share weights. A new node is - created every time a Layer is called. - tensor_index: The output index for this Tensor. Always zero if the Layer - that produced this Tensor only has one output. Nested structures of - Tensors are deterministically assigned an index via `nest.flatten`. - """ - # Added to maintain memory and performance characteristics of `namedtuple` - # while subclassing. - __slots__ = () - - # Avoid breaking users who directly import this symbol from this file. # TODO(fchollet): remove this. InputSpec = input_spec.InputSpec # pylint:disable=invalid-name diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index a0e1e9edc2f..30ac17d8270 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -254,8 +254,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): op_layer = base_layer.TensorFlowOpLayer( node_def, constants=constants, name=name) created_layers.append(op_layer) - op_layer._add_inbound_node( # pylint: disable=protected-access - layer_inputs, op.outputs) + op_layer._set_connectivity_metadata( # pylint: disable=protected-access + args=(layer_inputs,), + kwargs={}, + outputs=op.outputs) processed_ops.update([op]) return processed_ops, created_layers diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 7b4ce8ad54c..9d6f5afd240 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -45,7 +45,6 @@ from tensorflow.python.keras import regularizers from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec -from tensorflow.python.keras.engine import node as node_module from tensorflow.python.keras.mixed_precision.experimental import autocast_variable from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer from tensorflow.python.keras.mixed_precision.experimental import policy @@ -698,16 +697,17 @@ class Layer(base_layer.Layer): mask_arg_passed_by_framework = True kwargs['mask'] = input_masks - # If `training` argument was not explicitly passed, propagate `training` - # value from this layer's calling layer. + # If `training` argument is None or not explicitly passed, + # propagate `training` value from this layer's calling layer. + training_value = None training_arg_passed_by_framework = False # Priority 1: `training` was explicitly passed. if self._call_arg_was_passed('training', args, kwargs): training_value = self._get_call_arg_value('training', args, kwargs) if not self._expects_training_arg: kwargs.pop('training') - else: - training_value = None + + if training_value is None: # Priority 2: `training` was passed to a parent layer. if call_context.training is not None: training_value = call_context.training @@ -727,7 +727,8 @@ class Layer(base_layer.Layer): training_value = math_ops.cast(training_value, dtypes.bool) else: training_value = bool(training_value) - kwargs['training'] = training_value + args, kwargs = self._set_call_arg_value( + 'training', training_value, args, kwargs) training_arg_passed_by_framework = True # Only create Keras history if at least one tensor originates from a @@ -798,11 +799,12 @@ class Layer(base_layer.Layer): '(layer: ' + self.name + ').') if base_layer_utils.have_all_keras_metadata(inputs): if training_arg_passed_by_framework: - kwargs.pop('training') + args, kwargs = self._set_call_arg_value( + 'training', None, args, kwargs, pop_kwarg_if_none=True) if mask_arg_passed_by_framework: kwargs.pop('mask') - inputs, outputs = self._set_connectivity_metadata_( - inputs, outputs, args, kwargs) + outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, + outputs) self._handle_activity_regularization(inputs, outputs) self._set_mask_metadata(inputs, outputs, input_masks) if hasattr(self, '_set_inputs') and not self.inputs: @@ -2005,70 +2007,23 @@ class Layer(base_layer.Layer): args_dict = dict(zip(call_fn_args, args)) return args_dict[arg_name] - def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): - - # If the layer returns tensors from its inputs, unmodified, - # we copy them to avoid loss of tensor metadata. - output_ls = nest.flatten(outputs) - inputs_ls = object_identity.ObjectIdentitySet(nest.flatten(inputs)) - output_ls_copy = [] - for x in output_ls: - if x in inputs_ls: - with backend.name_scope(self.name): - x = array_ops.identity(x) - output_ls_copy.append(x) - outputs = nest.pack_sequence_as(outputs, output_ls_copy) - - # Ignore `inputs` arg. - arguments = dict(zip(self._call_fn_args[1:], args)) - arguments.update(kwargs) - - # Add an inbound node to the layer, so it can keep track of this call. - # This updates the layer history of the output tensor(s). - self._add_inbound_node( - input_tensors=inputs, output_tensors=outputs, arguments=arguments) - return inputs, outputs - - def _add_inbound_node(self, - input_tensors, - output_tensors, - arguments=None): - """Internal method to create an inbound node for the layer. - - Arguments: - input_tensors: list of input tensors. - output_tensors: list of output tensors. - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. - """ - inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, - input_tensors) - node_indices = nest.map_structure(lambda t: t._keras_history.node_index, - input_tensors) - tensor_indices = nest.map_structure(lambda t: t._keras_history.tensor_index, - input_tensors) - - # Create node, add it to inbound nodes. - node_module.Node( - self, - inbound_layers=inbound_layers, - node_indices=node_indices, - tensor_indices=tensor_indices, - input_tensors=input_tensors, - output_tensors=output_tensors, - arguments=arguments) - - # Update tensor history metadata. - # The metadata attribute consists of - # 1) a layer instance - # 2) a node index for the layer - # 3) a tensor index for the node. - # The allows layer reuse (multiple nodes per layer) and multi-output - # or multi-input layers (e.g. a layer can return multiple tensors, - # and each can be sent to a different layer). - for i, tensor in enumerate(nest.flatten(output_tensors)): - tensor._keras_history = KerasHistory(self, - len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access + def _set_call_arg_value( + self, arg_name, new_value, args, + kwargs, inputs_in_args=False, pop_kwarg_if_none=False): + arg_pos = self._call_fn_arg_positions.get(arg_name, None) + if arg_pos is not None: + if not inputs_in_args: + # Ignore `inputs` arg. + arg_pos = arg_pos - 1 + if len(args) > arg_pos: + args = list(args) + args[arg_pos] = new_value + return args, kwargs + if new_value is None and pop_kwarg_if_none: + kwargs.pop(arg_name, None) + else: + kwargs[arg_name] = new_value + return args, kwargs def _get_node_attribute_at_index(self, node_index, attr, attr_name): """Private utility to retrieves an attribute (e.g. inputs) from a node. @@ -2395,6 +2350,14 @@ class Layer(base_layer.Layer): return all_args[1:] return all_args + @property + @tracking.cached_per_instance + def _call_fn_arg_positions(self): + call_fn_arg_positions = dict() + for pos, arg in enumerate(self._call_fn_args): + call_fn_arg_positions[arg] = pos + return call_fn_arg_positions + @property @tracking.cached_per_instance def _call_accepts_kwargs(self): diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index efa3de959f1..1664e927da7 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -164,17 +164,9 @@ class InputLayer(base_layer.Layer): self.is_placeholder = False self._batch_input_shape = tuple(input_tensor.shape.as_list()) - # Create an input node to add to self.outbound_node - # and set output_tensors' _keras_history. - input_tensor._keras_history = base_layer.KerasHistory(self, 0, 0) + # Create an input node. input_tensor._keras_mask = None - node_module.Node( - self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=[input_tensor], - output_tensors=[input_tensor]) + node_module.Node(layer=self, outputs=input_tensor) def get_config(self): config = { @@ -294,8 +286,8 @@ def Input( # pylint: disable=invalid-name # Return tensor including `_keras_history`. # Note that in this case train_output and test_output are the same pointer. - outputs = input_layer._inbound_nodes[0].output_tensors - if len(outputs) == 1: + outputs = input_layer._inbound_nodes[0].outputs + if isinstance(outputs, list) and len(outputs) == 1: return outputs[0] else: return outputs diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 770f0468881..807576cb45b 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -25,13 +25,11 @@ import itertools import json import os -import numpy as np import six from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context from tensorflow.python.framework import composite_tensor -from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl from tensorflow.python.framework import func_graph @@ -42,7 +40,6 @@ from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import compile_utils from tensorflow.python.keras.engine import input_layer as input_layer_module -from tensorflow.python.keras.engine import node as node_module from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import save @@ -307,15 +304,6 @@ class Network(base_layer.Layer): self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call) layer._attribute_sentinel.add_parent(self._attribute_sentinel) - # Create the node linking internal inputs to internal outputs. - node_module.Node( - outbound_layer=self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=self._nested_inputs, - output_tensors=self._nested_outputs) - # Build self.input_names and self.output_names. self._set_output_names() self.input_names = [] @@ -337,6 +325,82 @@ class Network(base_layer.Layer): self._compute_tensor_usage_count() self._set_save_spec(self._nested_inputs) + @property + def input(self): + """Retrieves the input tensor(s) of a layer. + + Only applicable if the layer has exactly one input, + i.e. if it is connected to one incoming layer. + + Returns: + Input tensor or list of input tensors. + + Raises: + RuntimeError: If called in Eager mode. + AttributeError: If no inbound nodes are found. + """ + if self._is_graph_network: + return self._nested_inputs + return super(Network, self).input + + @property + def input_shape(self): + """Retrieves the input shape(s) of a layer. + + Only applicable if the layer has exactly one input, + i.e. if it is connected to one incoming layer, or if all inputs + have the same shape. + + Returns: + Input shape, as an integer shape tuple + (or list of shape tuples, one tuple per input tensor). + + Raises: + AttributeError: if the layer has no defined input_shape. + RuntimeError: if called in Eager mode. + """ + if self._is_graph_network: + return nest.map_structure(backend.int_shape, self.input) + return super(Network, self).input_shape + + @property + def output(self): + """Retrieves the output tensor(s) of a layer. + + Only applicable if the layer has exactly one output, + i.e. if it is connected to one incoming layer. + + Returns: + Output tensor or list of output tensors. + + Raises: + AttributeError: if the layer is connected to more than one incoming + layers. + RuntimeError: if called in Eager mode. + """ + if self._is_graph_network: + return self._nested_outputs + return super(Network, self).output + + @property + def output_shape(self): + """Retrieves the output shape(s) of a layer. + + Only applicable if the layer has one output, + or if all outputs have the same shape. + + Returns: + Output shape, as an integer shape tuple + (or list of shape tuples, one tuple per output tensor). + + Raises: + AttributeError: if the layer has no defined output shape. + RuntimeError: if called in Eager mode. + """ + if self._is_graph_network: + return nest.map_structure(backend.int_shape, self.output) + return super(Network, self).output_shape + def _set_output_names(self): """Assigns unique names to the Network's outputs. @@ -700,8 +764,7 @@ class Network(base_layer.Layer): ' implement a `call` method.') return self._run_internal_graph( - inputs, training=training, mask=mask, - convert_kwargs_to_constants=base_layer_utils.call_context().saving) + inputs, training=training, mask=mask) def compute_output_shape(self, input_shape): if not self._is_graph_network: @@ -741,20 +804,20 @@ class Network(base_layer.Layer): for depth in depth_keys: nodes = self._nodes_by_depth[depth] for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer + layer = node.layer if layer in self._input_layers: # We've already covered the input layers # a few lines above. continue - # Potentially redundant list, - # same size as node.input_tensors. + # Get the input shapes for the first argument of the node layer_input_shapes = [] - for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound(): - input_layer_key = inbound_layer.name + '_%s_%s' % (node_id, - tensor_id) + layer_inputs = node.call_args[0] + for layer_input in nest.flatten(layer_inputs): + kh = layer_input._keras_history + input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index, + kh.tensor_index) layer_input_shapes.append(layers_to_output_shapes[input_layer_key]) - layer_input_shapes = nest.pack_sequence_as(node.inbound_layers, + layer_input_shapes = nest.pack_sequence_as(layer_inputs, layer_input_shapes) # Layers expect shapes to be tuples for `compute_output_shape`. layer_input_shapes = tf_utils.convert_shapes( @@ -782,8 +845,7 @@ class Network(base_layer.Layer): # Return shapes as TensorShapes. return output_shapes - def _run_internal_graph(self, inputs, training=None, mask=None, - convert_kwargs_to_constants=False): + def _run_internal_graph(self, inputs, training=None, mask=None): """Computes output tensors for new inputs. # Note: @@ -793,21 +855,10 @@ class Network(base_layer.Layer): inputs: Tensor or nested structure of Tensors. training: Boolean learning phase. mask: (Optional) Tensor or nested structure of Tensors. - convert_kwargs_to_constants: Whether to convert Tensor kwargs to - constants. This is used when tracing the model call function during - saving to ensure that external tensors aren't captured. Returns: Two lists: output_tensors, output_masks """ - # Note: masking support is relevant mainly for Keras. - # It cannot be factored out without having the fully reimplement the network - # calling logic on the Keras side. We choose to incorporate it in - # Network because 1) it may be useful to fully support in tf.layers in - # the future and 2) Keras is a major user of Network. If you don't - # use masking, it does not interfere with regular behavior at all and you - # can ignore it. - inputs = self._flatten_to_reference_inputs(inputs) if mask is None: masks = [None for _ in range(len(inputs))] @@ -825,58 +876,26 @@ class Network(base_layer.Layer): depth_keys = list(self._nodes_by_depth.keys()) depth_keys.sort(reverse=True) - # Ignore the InputLayers when computing the graph. - depth_keys = depth_keys[1:] for depth in depth_keys: nodes = self._nodes_by_depth[depth] for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer + if node.is_input: + continue # Input tensors already exist. - if all( + if not all( str(id(tensor)) in tensor_dict - for tensor in nest.flatten(node.input_tensors)): + for tensor in nest.flatten(node.keras_inputs)): + continue # Node is not computable, try skipping. - # Call layer (reapplying ops to new inputs). - computed_tensors = nest.map_structure( - lambda t: tensor_dict[str(id(t))].pop(), node.input_tensors) + layer = node.layer + args, kwargs = node.map_arguments(tensor_dict) + outputs = layer(*args, **kwargs) - # Ensure `training` arg propagation if applicable. - kwargs = copy.copy(node.arguments) if node.arguments else {} - if convert_kwargs_to_constants: - kwargs = _map_tensors_to_constants(kwargs) - - argspec = self._layer_call_argspecs[layer].args - if 'training' in argspec: - if 'training' not in kwargs or kwargs['training'] is None: - kwargs['training'] = training - elif (type(kwargs['training']) is ops.Tensor and # pylint: disable=unidiomatic-typecheck - any([ - kwargs['training'] is x - for x in backend._GRAPH_LEARNING_PHASES.values() - ])): - kwargs['training'] = training # Materialize placeholder. - - # Map Keras tensors in kwargs to their computed value. - def _map_tensor_if_from_keras_layer(t): - if (isinstance(t, - (ops.Tensor, composite_tensor.CompositeTensor)) and - hasattr(t, '_keras_history')): - t_id = str(id(t)) - return tensor_dict[t_id].pop() - return t - - kwargs = nest.map_structure(_map_tensor_if_from_keras_layer, kwargs) - - # Compute outputs. - output_tensors = layer(computed_tensors, **kwargs) - - # Update tensor_dict. - for x, y in zip( - nest.flatten(node.output_tensors), nest.flatten(output_tensors)): - x_id = str(id(x)) - tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] + # Update tensor_dict. + for x, y in zip(nest.flatten(node.outputs), nest.flatten(outputs)): + x_id = str(id(x)) + tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] output_tensors = [] output_shapes = [] @@ -1367,7 +1386,7 @@ class Network(base_layer.Layer): # pylint: disable=protected-access layer = x._keras_history.layer if len(layer._inbound_nodes) > 1 or ( - layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): + layer._inbound_nodes and not layer._inbound_nodes[0].is_input): cls_name = self.__class__.__name__ logging.warning(cls_name + ' inputs must come from ' '`tf.keras.Input` (thus holding past layer metadata), ' @@ -1437,7 +1456,7 @@ class Network(base_layer.Layer): def _get_min_depth(node): """Gets the minimum depth at which node can be computed.""" min_depth = 0 - for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True): + for layer, node_id, _, _ in node.iterate_inbound(): inbound_node = layer._inbound_nodes[node_id] if inbound_node in node_to_depth: min_depth = min(min_depth, node_to_depth[inbound_node]) @@ -1465,8 +1484,8 @@ class Network(base_layer.Layer): if depth is None: # Defer until inbound nodes are processed. unprocessed_nodes.append(node) continue - node_key = _make_node_key(node.outbound_layer.name, - node.outbound_layer._inbound_nodes.index(node)) + node_key = _make_node_key(node.layer.name, + node.layer._inbound_nodes.index(node)) if node_key not in self._network_nodes: node_to_depth[node] = depth self._network_nodes.add(node_key) @@ -1506,21 +1525,13 @@ class Network(base_layer.Layer): for depth in depth_keys: for node in self._nodes_by_depth[depth]: input_tensors = { - str(id(tensor)) for tensor in nest.flatten(node.input_tensors) + str(id(tensor)) for tensor in nest.flatten(node.keras_inputs) } if input_tensors.issubset(available_tensors): - kwargs = copy.copy(node.arguments) if node.arguments else {} - - for tensor in nest.flatten(kwargs): - if (isinstance(tensor, - (ops.Tensor, composite_tensor.CompositeTensor)) and - hasattr(tensor, '_keras_history')): - tensor_usage_count[str(id(tensor))] += 1 - - for tensor in nest.flatten(node.input_tensors): + for tensor in nest.flatten(node.keras_inputs): tensor_usage_count[str(id(tensor))] += 1 - for output_tensor in nest.flatten(node.output_tensors): + for output_tensor in nest.flatten(node.outputs): available_tensors.add(str(id(output_tensor))) for tensor in self.outputs: @@ -1631,97 +1642,35 @@ def _map_graph_network(inputs, outputs): Raises: ValueError: In case the network is not valid (e.g. disconnected graph). """ - # Network_nodes: set of nodes included in the graph of layers - # (not all nodes included in the layers are relevant to the current graph). - network_nodes = set() # ids of all nodes relevant to the Network + # "depth" is number of layers between output Node and the Node. + # Nodes are ordered from inputs -> outputs. + nodes_in_decreasing_depth, layer_indices = _build_map(outputs) + network_nodes = { + _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node)) + for node in nodes_in_decreasing_depth + } + nodes_depths = {} # dict {node: depth value} layers_depths = {} # dict {layer: depth value} - layer_indices = {} # dict {layer: index in traversal} - nodes_in_decreasing_depth = [] - - def build_map(tensor, - finished_nodes, - nodes_in_progress, - layer, - node_index, - tensor_index): - """Builds a map of the graph of layers. - - This recursively updates the map `layer_indices`, - the list `nodes_in_decreasing_depth` and the set `network_nodes`. - - Arguments: - tensor: Some tensor in a graph. - finished_nodes: Set of nodes whose subgraphs have been traversed - completely. Useful to prevent duplicated work. - nodes_in_progress: Set of nodes that are currently active on the - recursion stack. Useful to detect cycles. - layer: Layer from which `tensor` comes from. If not provided, - will be obtained from `tensor._keras_history`. - node_index: Node index from which `tensor` comes from. - tensor_index: Tensor_index from which `tensor` comes from. - - Raises: - ValueError: if a cycle is detected. - """ - node = layer._inbound_nodes[node_index] # pylint: disable=protected-access - - # Prevent cycles. - if node in nodes_in_progress: - raise ValueError('The tensor ' + str(tensor) + ' at layer "' + - layer.name + '" is part of a cycle.') - - # Don't repeat work for shared subgraphs - if node in finished_nodes: - return - - node_key = _make_node_key(layer.name, node_index) - # Update network_nodes. - network_nodes.add(node_key) - - # Store the traversal order for layer sorting. - if layer not in layer_indices: - layer_indices[layer] = len(layer_indices) - - nodes_in_progress.add(node) - - # Propagate to all previous tensors connected to this node. - for layer, node_index, tensor_index, tensor in node.iterate_inbound( - include_arguments=True): - build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, - tensor_index) - - finished_nodes.add(node) - nodes_in_progress.remove(node) - nodes_in_decreasing_depth.append(node) - - finished_nodes = set() - nodes_in_progress = set() - for x in outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - build_map(x, finished_nodes, nodes_in_progress, - layer=layer, - node_index=node_index, - tensor_index=tensor_index) for node in reversed(nodes_in_decreasing_depth): # If the depth is not set, the node has no outbound nodes (depth 0). depth = nodes_depths.setdefault(node, 0) # Update the depth of the corresponding layer - previous_depth = layers_depths.get(node.outbound_layer, 0) + previous_depth = layers_depths.get(node.layer, 0) # If we've seen this layer before at a higher depth, # we should use that depth instead of the node depth. # This is necessary for shared layers that have inputs at different # depth levels in the graph. depth = max(depth, previous_depth) - layers_depths[node.outbound_layer] = depth + layers_depths[node.layer] = depth nodes_depths[node] = depth # Update the depth of inbound nodes. # The "depth" of a node is the max of the depths # of all nodes it is connected to + 1. - for node_dep in node._get_all_node_dependencies(): + for node_dep in node.parent_nodes: previous_depth = nodes_depths.get(node_dep, 0) nodes_depths[node_dep] = max(depth + 1, previous_depth) @@ -1773,9 +1722,9 @@ def _map_graph_network(inputs, outputs): layers_with_complete_input = [] # To provide a better error msg. for depth in depth_keys: for node in nodes_by_depth[depth]: - layer = node.outbound_layer - if layer: - for x in nest.flatten(node.input_tensors): + layer = node.layer + if layer and not node.is_input: + for x in nest.flatten(node.keras_inputs): if id(x) not in computable_tensors: raise ValueError('Graph disconnected: ' 'cannot obtain value for tensor ' + str(x) + @@ -1783,7 +1732,7 @@ def _map_graph_network(inputs, outputs): 'The following previous layers ' 'were accessed without issue: ' + str(layers_with_complete_input)) - for x in nest.flatten(node.output_tensors): + for x in nest.flatten(node.outputs): computable_tensors.add(id(x)) layers_with_complete_input.append(layer.name) @@ -1798,6 +1747,68 @@ def _map_graph_network(inputs, outputs): return network_nodes, nodes_by_depth, layers, layers_by_depth +def _build_map(outputs): + """This method topologically sorts nodes in order from inputs to outputs. + + It uses a depth-first search to topologically sort nodes that appear in the + _keras_history connectivity metadata of `outputs`. + + Args: + outputs: the output tensors whose _keras_history metadata should be walked. + This may be an arbitrary nested structure. + + Returns: + A tuple like (ordered_nodes, layer_to_first_traversal_index) + ordered_nodes: list of nodes appearing in the keras history, topologically + sorted from original inputs to the `outputs`. + (If outputs have different sets of ancestors, the inputs to one output + may appear after a different output). + layer_to_first_traversal_index: + A dict mapping layer to the traversal index in the DFS where it is + seen. Note: if a layer is shared by several nodes, the dict will only + store the index corresponding to the *first* time the layer seen. + """ + finished_nodes = set() + nodes_in_progress = set() + nodes_in_decreasing_depth = [] # nodes from inputs -> outputs. + layer_indices = {} # layer -> in traversal order. + for output in nest.flatten(outputs): + _build_map_helper(output, finished_nodes, nodes_in_progress, + nodes_in_decreasing_depth, layer_indices) + return nodes_in_decreasing_depth, layer_indices + + +def _build_map_helper(tensor, finished_nodes, nodes_in_progress, + nodes_in_decreasing_depth, layer_indices): + """Recursive helper for `_build_map`.""" + layer, node_index, _ = tensor._keras_history # pylint: disable=protected-access + node = layer._inbound_nodes[node_index] # pylint: disable=protected-access + + # Don't repeat work for shared subgraphs + if node in finished_nodes: + return + + # Prevent cycles. + if node in nodes_in_progress: + raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name + + '" is part of a cycle.') + + # Store the traversal order for layer sorting. + if layer not in layer_indices: + layer_indices[layer] = len(layer_indices) + + # Propagate to all previous tensors connected to this node. + nodes_in_progress.add(node) + if not node.is_input: + for tensor in node.keras_inputs: + _build_map_helper(tensor, finished_nodes, nodes_in_progress, + nodes_in_decreasing_depth, layer_indices) + + finished_nodes.add(node) + nodes_in_progress.remove(node) + nodes_in_decreasing_depth.append(node) + + def _map_subgraph_network(inputs, outputs): """Returns the nodes and layers in the topology from `inputs` to `outputs`. @@ -1820,36 +1831,6 @@ def _should_skip_first_node(layer): return issubclass(layer.__class__, Network) and layer._is_graph_network -def _serialize_tensors(kwargs): - """Serializes Tensors passed to `call`.""" - - def _serialize_keras_tensor(t): - """Serializes a single Tensor passed to `call`.""" - if hasattr(t, '_keras_history'): - kh = t._keras_history - return [kh.layer.name, kh.node_index, kh.tensor_index] - - if isinstance(t, np.ndarray): - return t.tolist() - - if isinstance(t, ops.Tensor): - return backend.get_value(t).tolist() - - return t - - return nest.map_structure(_serialize_keras_tensor, kwargs) - - -def _map_tensors_to_constants(kwargs): - - def _map_to_constants(t): - if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor): - return constant_op.constant(backend.get_value(t)) - return t - - return nest.map_structure(_map_to_constants, kwargs) - - def _deserialize_keras_tensors(kwargs, layer_map): """Deserializes Keras Tensors passed to `call`..""" @@ -1863,7 +1844,7 @@ def _deserialize_keras_tensors(kwargs, layer_map): layer = layer_map[layer_name] node = layer._inbound_nodes[node_index] - return nest.flatten(node.output_tensors)[tensor_index] + return nest.flatten(node.outputs)[tensor_index] return t kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True) @@ -1961,7 +1942,7 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None): return inbound_node = inbound_layer._inbound_nodes[inbound_node_index] input_tensors.append( - nest.flatten(inbound_node.output_tensors)[inbound_tensor_index]) + nest.flatten(inbound_node.outputs)[inbound_tensor_index]) input_tensors = nest.pack_sequence_as(node_data, input_tensors) # Call layer on its inputs, thus creating the node # and building the layer if needed. @@ -2077,37 +2058,11 @@ def get_network_config(network, serialize_layer_fn=None): filtered_inbound_nodes = [] for original_node_index, node in enumerate(layer._inbound_nodes): node_key = _make_node_key(layer.name, original_node_index) - if node_key in network._network_nodes: + if node_key in network._network_nodes and not node.is_input: # The node is relevant to the model: # add to filtered_inbound_nodes. - if node.arguments: - kwargs = _serialize_tensors(node.arguments) - try: - json.dumps(kwargs) - except TypeError: - logging.warning( - 'Layer ' + layer.name + - ' was passed non-serializable keyword arguments: ' + - str(node.arguments) + '. They will not be included ' - 'in the serialized model (and thus will be missing ' - 'at deserialization time).') - kwargs = {} - else: - kwargs = {} - if node.inbound_layers: - node_data = [] - for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound(): - node_key = _make_node_key(inbound_layer.name, node_id) - new_node_index = node_conversion_map.get(node_key, 0) - node_data.append( - tf_utils.ListWrapper( - [inbound_layer.name, new_node_index, tensor_id, kwargs])) - node_data = nest.pack_sequence_as(node.input_tensors, node_data) - if not nest.is_sequence(node_data): - node_data = [node_data] - # Convert ListWrapper to list for backwards compatible configs. - node_data = tf_utils.convert_inner_node_data(node_data) - filtered_inbound_nodes.append(node_data) + node_data = node.serialize(_make_node_key, node_conversion_map) + filtered_inbound_nodes.append(node_data) layer_config = serialize_layer_fn(layer) layer_config['name'] = layer.name diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py index 429a096037f..7c19a3ae2bd 100644 --- a/tensorflow/python/keras/engine/network_test.py +++ b/tensorflow/python/keras/engine/network_test.py @@ -1785,6 +1785,7 @@ class AttrTrackingLayer(base_layer.Layer): return super(AttrTrackingLayer, self).dynamic +@combinations.generate(combinations.combine(mode=['graph', 'eager'])) class CacheCorrectnessTest(keras_parameterized.TestCase): def layer_and_network_test(self): @@ -1919,8 +1920,12 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): class MyLayer(base_layer.Layer): def call(self, x, training=None): - self.training = training - return x + if training is None: + return x * -1.0 + elif training: + return x + else: + return x * 0.0 my_layer = MyLayer() x = np.ones((1, 10)) @@ -1929,9 +1934,8 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): outputs = my_layer(inputs, training=True) network = network_lib.Network(inputs, outputs) - network(x, training=False) # Hard-coded value passed during construction is respected. - self.assertTrue(my_layer.training) + self.assertAllEqual(network(x, training=False), x) inputs = input_layer_lib.Input(10) outputs = my_layer(inputs, training=False) @@ -1939,19 +1943,16 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): network(x, training=True) # Hard-coded value passed during construction is respected. - self.assertFalse(my_layer.training) + self.assertAllEqual(network(x, training=True), x * 0.0) inputs = input_layer_lib.Input(10) outputs = my_layer(inputs, training=None) network = network_lib.Network(inputs, outputs) - network(x, training=True) # `None` value passed during construction is overridden. - self.assertTrue(my_layer.training) - network(x, training=False) + self.assertAllEqual(network(x, training=True), x) # `None` value passed during construction is overridden. - self.assertFalse(my_layer.training) - + self.assertAllEqual(network(x, training=False), x * 0.0) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py index c8255764efc..945cf1c64bd 100644 --- a/tensorflow/python/keras/engine/node.py +++ b/tensorflow/python/keras/engine/node.py @@ -18,10 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import copy +import json +import numpy as np + from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import serialization class Node(object): @@ -33,163 +41,227 @@ class Node(object): a node is added to `layer._outbound_nodes`. Arguments: - outbound_layer: the layer that takes - `input_tensors` and turns them into `output_tensors` - (the node gets created when the `call` - method of the layer was called). - inbound_layers: a list of layers, the same length as `input_tensors`, - the layers from where `input_tensors` originate. - node_indices: a list of integers, the same length as `inbound_layers`. - `node_indices[i]` is the origin node of `input_tensors[i]` - (necessary since each inbound layer might have several nodes, - e.g. if the layer is being shared with a different data stream). - tensor_indices: a list of integers, - the same length as `inbound_layers`. - `tensor_indices[i]` is the index of `input_tensors[i]` within the - output of the inbound layer - (necessary since each inbound layer might - have multiple tensor outputs, with each one being - independently manipulable). - input_tensors: list of input tensors. - output_tensors: list of output tensors. - arguments: dictionary of keyword arguments that were passed to the - `call` method of the layer at the call that created the node. - - `node_indices` and `tensor_indices` are basically fine-grained coordinates - describing the origin of the `input_tensors`. - - A node from layer A to layer B is added to: - - A._outbound_nodes - - B._inbound_nodes + layer: The Layer for the Layer.__call__ this node represents. + call_args: The positional arguments the Layer was called with. + call_kwargs: The keyword arguments the Layer was called with. + outputs: The outputs of the Layer.__call__ """ def __init__(self, - outbound_layer, - inbound_layers, - node_indices, - tensor_indices, - input_tensors, - output_tensors, - arguments=None): - # Layer instance (NOT a sequence) - if isinstance(outbound_layer, (list, tuple, dict)): - raise ValueError('`outbound_layer` should be a layer instance, ' - 'not a list, tuple, or, dict.') + layer, + call_args=None, + call_kwargs=None, + outputs=None): + call_args = [] if call_args is None else call_args + call_kwargs = {} if call_kwargs is None else call_kwargs + outputs = [] if outputs is None else outputs - # These arguments are user-provided. Copy them here so that future - # user modifications do not affect the node's metadata. - input_tensors = nest.map_structure(lambda t: t, input_tensors) - output_tensors = nest.map_structure(lambda t: t, output_tensors) - arguments = nest.map_structure(lambda t: t, arguments) + self.layer = layer + self.is_input = not call_args and not call_kwargs - # this is the layer that takes a nested structure of input tensors - # and turns them into a nested structure of output tensors. - # the current node will be added to - # the inbound_nodes of outbound_layer. - self.outbound_layer = outbound_layer + # These arguments are user-provided. Copy the structures here so that + # future user modifications do not affect the node's metadata. + # We copy using map_structure rather than python's shallow or deep copy, + # because the args can be data structures (so shallow copy is + # insufficient), but individual values might not support copy.copy + # or be too expensive to deep copy. + call_args = nest.map_structure(lambda t: t, call_args) + call_kwargs = nest.map_structure(lambda t: t, call_kwargs) + self.outputs = nest.map_structure(lambda t: t, outputs) + self.call_args = call_args + self.call_kwargs = call_kwargs - # The following 3 properties describe where - # the input tensors come from: which layers, - # and for each layer, which node and which - # tensor output of each node. + # Cached for performance. + self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs)) - # Nested structure of layer instances. - self.inbound_layers = inbound_layers - # Nested structure of integers, 1:1 mapping with inbound_layers. - self.node_indices = node_indices - # Nested of integers, 1:1 mapping with inbound_layers. - self.tensor_indices = tensor_indices + # Create TensorFlowOpLayers if needed. + for obj in self._flat_arguments: + if (isinstance(obj, ops.Tensor) and + base_layer_utils.needs_keras_history(obj, ignore_call_context=True)): + base_layer_utils.create_keras_history(obj) - # Following 2 properties: - # tensor inputs and outputs of outbound_layer. + self._keras_inputs = [] + self._keras_inputs_ids_and_indices = [] + for i, ele in enumerate(self._flat_arguments): + if is_keras_tensor(ele): + self._keras_inputs.append(ele) + kt_id = str(id(ele)) + kt_index = i + self._keras_inputs_ids_and_indices.append((kt_id, kt_index)) - # Nested structure of tensors. 1:1 mapping with inbound_layers. - self.input_tensors = input_tensors - # Nested structure of tensors, created by outbound_layer.call(). - self.output_tensors = output_tensors + # Wire up Node to Layers. + self.layer._inbound_nodes.append(self) + for kt in self.keras_inputs: + inbound_layer = kt._keras_history.layer + if inbound_layer is not None: # `None` for `Input` tensors. + inbound_layer._outbound_nodes.append(self) - # Following 2 properties: input and output shapes. + # Set metadata on outputs. + node_index = len(self.layer._inbound_nodes) - 1 + for i, tensor in enumerate(nest.flatten(outputs)): + tensor._keras_history = KerasHistory( + layer=layer, node_index=node_index, tensor_index=i) - # Nested structure of shape tuples, shapes of input_tensors. - self.input_shapes = nest.map_structure(backend.int_shape, input_tensors) - # Nested structure of shape tuples, shapes of output_tensors. - self.output_shapes = nest.map_structure(backend.int_shape, output_tensors) + @property + def keras_inputs(self): + """Tensors input to this node that can be traced back to a `keras.Input`.""" + return self._keras_inputs - # Optional keyword arguments to layer's `call`. - self.arguments = arguments - - # Create Keras History for any Keras Tensors in `arguments`. - tensor_arguments = [ - t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor) - ] - for tensor_argument in tensor_arguments: - if base_layer_utils.needs_keras_history( - tensor_argument, ignore_call_context=True): - base_layer_utils.create_keras_history(tensor_argument) - - # Add nodes to all layers involved. - for layer in nest.flatten(inbound_layers): - if layer is not None: - # For compatibility with external Keras, we use the deprecated - # accessor here. - layer.outbound_nodes.append(self) - # For compatibility with external Keras, we use the deprecated - # accessor here. - outbound_layer.inbound_nodes.append(self) - - def iterate_inbound(self, include_arguments=False): - """Returns a list of tuples representing the inbound data. - - Arguments: - include_arguments: Whether to also iterate over any Keras Tensors - passed as args, kwargs. - - Returns: - List of tuples like: (inbound_layer, node_index, tensor_index, tensor). - """ - inputs_inbound = list( - zip( - nest.flatten(self.inbound_layers), - nest.flatten(self.node_indices), - nest.flatten(self.tensor_indices), - nest.flatten(self.input_tensors))) - - if include_arguments: - keras_tensor_arguments = [ - kt for kt in nest.flatten(self.arguments) - if hasattr(kt, '_keras_history') - ] - - def _get_inbound(keras_tensor): - kh = keras_tensor._keras_history - return kh.layer, kh.node_index, kh.tensor_index, keras_tensor - - arguments_inbound = nest.map_structure(_get_inbound, - keras_tensor_arguments) - - return inputs_inbound + arguments_inbound - else: - return inputs_inbound - - def _get_all_node_dependencies(self): - """Returns all of the nodes this node immediately depends on.""" + @property + def parent_nodes(self): + """Returns all the `Node`s whose output this node immediately depends on.""" node_deps = [] - for layer, node_index, _, _ in self.iterate_inbound(): - node_deps.append(layer._inbound_nodes[node_index]) - - for arg in nest.flatten(self.arguments): - if isinstance(arg, ops.Tensor) and hasattr(arg, '_keras_history'): - kh = arg._keras_history - node_deps.append(kh.layer._inbound_nodes[kh.node_index]) - + for kt in self.keras_inputs: + layer = kt._keras_history.layer + node_index = kt._keras_history.node_index + if layer is not None: # `None` for `Input` tensors. + node_deps.append(layer._inbound_nodes[node_index]) return node_deps - def get_config(self): - inbound_names = nest.map_structure( - lambda layer: layer.name if layer else None, self.inbound_layers) - return { - 'outbound_layer': self.outbound_layer.name, - 'inbound_layers': inbound_names, - 'node_indices': self.node_indices, - 'tensor_indices': self.tensor_indices - } + def iterate_inbound(self): + """Yields tuples representing the data inbound from other nodes. + + Yields: + tuples like: (inbound_layer, node_index, tensor_index, tensor). + """ + for kt in self.keras_inputs: + keras_history = kt._keras_history + layer = keras_history.layer + node_index = keras_history.node_index + tensor_index = keras_history.tensor_index + yield layer, node_index, tensor_index, kt + + def map_arguments(self, tensor_dict): + """Maps Keras Tensors to computed Tensors using `tensor_dict`.""" + flat_arguments = copy.copy(self._flat_arguments) + for kt_id, kt_index in self._keras_inputs_ids_and_indices: + flat_arguments[kt_index] = tensor_dict[kt_id].pop() + + args, kwargs = nest.pack_sequence_as( + (self.call_args, self.call_kwargs), flat_arguments) + return args, kwargs + + def serialize(self, make_node_key, node_conversion_map): + """Serializes `Node` for Functional API's `get_config`.""" + # Serialization still special-cases first argument. + args, kwargs = self.call_args, self.call_kwargs + inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs) + + # Treat everything other than first argument as a kwarg. + arguments = dict(zip(self.layer._call_fn_args[1:], args)) + arguments.update(kwargs) + kwargs = arguments + + kwargs = nest.map_structure(_serialize_keras_tensor, kwargs) + try: + json.dumps(kwargs, default=serialization.get_json_type) + except TypeError: + kwarg_types = nest.map_structure(type, kwargs) + logging.warning('Layer ' + self.layer.name + + ' was passed non-JSON-serializable arguments. ' + + 'Arguments had types: ' + + str(kwarg_types) + '. They will not be included ' + 'in the serialized model (and thus will be missing ' + 'at deserialization time).') + kwargs = {} + + # `kwargs` is added to each Tensor in the first arg. This should be + # changed in a future version of the serialization format. + def serialize_first_arg_tensor(t): + kh = t._keras_history + node_index = kh.node_index + node_key = make_node_key(kh.layer.name, node_index) + new_node_index = node_conversion_map.get(node_key, 0) + data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs] + return tf_utils.ListWrapper(data) + + data = nest.map_structure(serialize_first_arg_tensor, inputs) + if not nest.is_sequence(data): + data = [data] + data = tf_utils.convert_inner_node_data(data) + return data + + ############################################################# + # Properties for Backwards compatibility. + # These only check the first input argument + # As nodes are internal, they may be removed in the future. + ############################################################# + + @property + def input_tensors(self): + if self.is_input: + return [self.outputs] # Used in `Layer.input`. + return self.call_args[0] + + @property + def output_tensors(self): + if self.is_input: + return [self.outputs] # Used in `Layer.input`. + return self.outputs + + @property + def input_shapes(self): + input_shapes = nest.map_structure(backend.int_shape, self.input_tensors) + if len(input_shapes) == 1 and not self.is_input: + return input_shapes[0] + return input_shapes + + @property + def output_shapes(self): + return nest.map_structure(backend.int_shape, self.output_tensors) + + @property + def outbound_layer(self): + return self.layer + + @property + def inbound_layers(self): + if self.is_input: + return [] + inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, + self.call_args[0]) + return inbound_layers + + +class KerasHistory( + collections.namedtuple('KerasHistory', + ['layer', 'node_index', 'tensor_index'])): + """Tracks the Layer call that created a Tensor, for Keras Graph Networks. + + During construction of Keras Graph Networks, this metadata is added to + each Tensor produced as the output of a Layer, starting with an + `InputLayer`. This allows Keras to track how each Tensor was produced, and + this information is later retraced by the `keras.engine.Network` class to + reconstruct the Keras Graph Network. + + Attributes: + layer: The Layer that produced the Tensor. + node_index: The specific call to the Layer that produced this Tensor. Layers + can be called multiple times in order to share weights. A new node is + created every time a Layer is called. + tensor_index: The output index for this Tensor. Always zero if the Layer + that produced this Tensor only has one output. Nested structures of + Tensors are deterministically assigned an index via `nest.flatten`. + """ + # Added to maintain memory and performance characteristics of `namedtuple` + # while subclassing. + __slots__ = () + + +def is_keras_tensor(obj): + return hasattr(obj, '_keras_history') + + +def _serialize_keras_tensor(t): + """Serializes a single Tensor passed to `call`.""" + if hasattr(t, '_keras_history'): + kh = t._keras_history + return [kh.layer.name, kh.node_index, kh.tensor_index] + + if isinstance(t, np.ndarray): + return t.tolist() + + if isinstance(t, ops.Tensor): + return backend.get_value(t).tolist() + + return t diff --git a/tensorflow/python/keras/engine/node_test.py b/tensorflow/python/keras/engine/node_test.py new file mode 100644 index 00000000000..80c5144da1b --- /dev/null +++ b/tensorflow/python/keras/engine/node_test.py @@ -0,0 +1,160 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#,============================================================================ +"""Tests for layer graphs construction & handling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import node as node_module +from tensorflow.python.platform import test + + +class DummyTensor(object): + + def __init__(self, shape=None): + self.shape = shape + + +class DummyLayer(base_layer.Layer): + pass + + +class NetworkConstructionTest(keras_parameterized.TestCase): + + def test_chained_node_construction(self): + # test basics + a = DummyTensor(shape=(None, 32)) + b = DummyTensor(shape=(None, 32)) + + a_layer = DummyLayer() + node = node_module.Node(a_layer, outputs=a) + self.assertEqual(node.outbound_layer, a_layer) + + self.assertTrue(node.is_input) + self.assertListEqual(node.inbound_layers, []) + self.assertListEqual(node.input_tensors, [a]) + self.assertListEqual(node.input_shapes, [(None, 32)]) + self.assertListEqual(node.output_tensors, [a]) + self.assertListEqual(node.output_shapes, [(None, 32)]) + + b_layer = DummyLayer() + node_module.Node(b_layer, outputs=b) + + dense = DummyLayer() + a_2 = DummyTensor() + node_a = node_module.Node(layer=dense, call_args=(a,), outputs=a_2) + b_2 = DummyTensor() + node_b = node_module.Node(layer=dense, call_args=(b,), outputs=b_2) + + # test the node attributes + self.assertFalse(node_a.is_input) + self.assertFalse(node_b.is_input) + self.assertEqual(node_a.call_args, (a,)) + self.assertEqual(node_a.call_kwargs, {}) + self.assertEqual(node_a.outputs, a_2) + + # Test the layer wiring + self.assertLen(dense._inbound_nodes, 2) + self.assertLen(dense._outbound_nodes, 0) + self.assertEqual(dense._inbound_nodes, [node_a, node_b]) + self.assertEqual(dense._inbound_nodes[0].inbound_layers, a_layer) + self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense) + self.assertEqual(dense._inbound_nodes[1].inbound_layers, b_layer) + self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense) + self.assertIs(dense._inbound_nodes[0].input_tensors, a) + self.assertIs(dense._inbound_nodes[1].input_tensors, b) + + def test_multi_input_node(self): + # test multi-input layer + a = DummyTensor() + b = DummyTensor() + + dense = DummyLayer() + a_2 = DummyTensor() + node_module.Node(layer=dense, call_args=(a,), outputs=a_2) + b_2 = DummyTensor() + node_module.Node(layer=dense, call_args=(b,), outputs=b_2) + + concat_layer = DummyLayer() + merged = DummyTensor() + node_module.Node(layer=concat_layer, call_args=([a_2, b_2],), + outputs=merged) + + merge_layer, merge_node_index, merge_tensor_index = merged._keras_history + + self.assertEqual(merge_node_index, 0) + self.assertEqual(merge_tensor_index, 0) + + self.assertLen(merge_layer._inbound_nodes, 1) + self.assertLen(merge_layer._outbound_nodes, 0) + + self.assertLen(merge_layer._inbound_nodes[0].input_tensors, 2) + self.assertEqual(merge_layer._inbound_nodes[0].input_tensors, [a_2, b_2]) + self.assertLen(merge_layer._inbound_nodes[0].inbound_layers, 2) + + def test_arg_and_kwarg_mix(self): + input_layer = DummyLayer() + input_layer_2 = DummyLayer() + a = DummyTensor() + node_a = node_module.Node(layer=input_layer, outputs=a) + b = DummyTensor() + node_b = node_module.Node(layer=input_layer_2, outputs=b) + + arg_2 = DummyTensor() + arg_3 = DummyTensor() + node_c = node_module.Node(layer=input_layer, outputs=arg_3) + + kwarg_x = DummyTensor() + kwarg_y = DummyTensor() + node_d = node_module.Node(layer=input_layer, outputs=kwarg_y) + + merge_layer = DummyLayer() + merged = DummyTensor() + node = node_module.Node(layer=merge_layer, + call_args=([a, b], arg_2, arg_3), + call_kwargs={'x': kwarg_x, 'y': kwarg_y}, + outputs=merged) + + merge_layer, merge_node_index, merge_tensor_index = merged._keras_history + + # Check the saved call args/kwargs + self.assertEqual(([a, b], arg_2, arg_3), node.call_args) + self.assertEqual({'x': kwarg_x, 'y': kwarg_y}, node.call_kwargs) + + # Only the inputs that were produced by input nodes should appear in + # keras_tensors + self.assertEqual({a, b, arg_3, kwarg_y}, set(node.keras_inputs)) + self.assertEqual(set(node.parent_nodes), {node_a, node_b, node_c, node_d}) + + # Check the layer wirings + self.assertEqual(merge_node_index, 0) + self.assertEqual(merge_tensor_index, 0) + self.assertLen(merge_layer._inbound_nodes, 1) + self.assertLen(merge_layer._outbound_nodes, 0) + self.assertLen(input_layer._outbound_nodes, 3) + self.assertLen(input_layer_2._outbound_nodes, 1) + + # The 'backwards compatibility' attributes should only check the + # first call argument + self.assertLen(merge_layer._inbound_nodes[0].input_tensors, 2) + self.assertEqual(merge_layer._inbound_nodes[0].input_tensors, [a, b]) + self.assertLen(merge_layer._inbound_nodes[0].inbound_layers, 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 9edfa4f958b..2d5abac7fd6 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -215,7 +215,7 @@ class Sequential(training.Model): set_inputs = True if set_inputs: - outputs = nest.flatten(layer._inbound_nodes[-1].output_tensors) + outputs = nest.flatten(layer._inbound_nodes[-1].outputs) if len(outputs) != 1: raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG) self.outputs = outputs diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 0cbb70109fc..a93bba0271b 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib - from absl.testing import parameterized import numpy as np @@ -302,11 +300,11 @@ class CorrectnessTest(keras_parameterized.TestCase): self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4) @parameterized.named_parameters([ - ('_None', contextlib.contextmanager(lambda: iter([None])), 0., 4.), - ('_0', lambda: keras.backend.learning_phase_scope(0), 4., 4.), - ('_1', lambda: keras.backend.learning_phase_scope(1), 0., 0.), + ('_None', None, 0., 4.), + ('_False', False, 4., 4.), + ('_True', True, 0., 0.), ]) - def test_nested_model_learning_phase(self, nested_scope_fn, + def test_nested_model_learning_phase(self, training, expected_training_loss, expected_validation_loss): """Tests that learning phase is correctly set in an intermediate layer.""" @@ -326,18 +324,17 @@ class CorrectnessTest(keras_parameterized.TestCase): return keras.Model(inputs, outputs) def _regularize_model(unregularized_model): - inputs = keras.Input(unregularized_model.inputs[0].shape[1:]) - with nested_scope_fn(): - logits = unregularized_model(inputs) - outputs = keras.activations.softmax(logits) - model = keras.Model(inputs, outputs) # Regularize the most recent activations of a post-dropout layer. sample_activations = unregularized_model.get_layer( index=-2).get_output_at(-1) regularization_loss = keras.backend.mean(sample_activations) - model.add_loss(regularization_loss) - model.add_metric( + unregularized_model.add_loss(regularization_loss) + unregularized_model.add_metric( regularization_loss, aggregation='mean', name='regularization_loss') + inputs = keras.Input(unregularized_model.inputs[0].shape[1:]) + logits = unregularized_model(inputs, training=training) + outputs = keras.activations.softmax(logits) + model = keras.Model(inputs, outputs) return model # Make and compile models. diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 0b0121f521e..eaffb90e64b 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -112,11 +112,12 @@ def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map): # then call node.inbound_layer on them. if all( tensor in tensor_map for tensor in nest.flatten(node.input_tensors)): - computed_tensors = nest.map_structure(lambda t: tensor_map[t], - node.input_tensors) # Call layer. - kwargs = node.arguments or {} - output_tensors = layer(computed_tensors, **kwargs) + args = nest.map_structure(lambda t: tensor_map.get(t, t), + node.call_args) + kwargs = nest.map_structure(lambda t: tensor_map.get(t, t), + node.call_kwargs) + output_tensors = layer(*args, **kwargs) # Thread-safe way to keep track of what node was created. first_output_tensor = nest.flatten(output_tensors)[0] diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index f56d55b18d5..6c2037c4f4b 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -547,7 +547,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): else: return inputs - t = array_ops.sequence_mask(1) + t = self.evaluate(array_ops.sequence_mask(1)) inputs = keras.layers.Input(shape=(3)) model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t)) diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index 1dfd2f517c6..d2d3d919fff 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -55,7 +55,7 @@ def get_source_inputs(tensor, layer=None, node_index=None): return [tensor] else: node = layer._inbound_nodes[node_index] - if not node.inbound_layers: + if node.is_input: # Reached an Input layer, stop recursion. return nest.flatten(node.input_tensors) else: @@ -140,7 +140,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): nodes = [] for v in nodes_by_depth: if (len(v) > 1) or (len(v) == 1 and - len(nest.flatten(v[0].inbound_layers)) > 1): + len(nest.flatten(v[0].keras_inputs)) > 1): # if the model has multiple nodes # or if the nodes have multiple inbound_layers # the model is no longer sequential