Stops special-casing the first argument in Keras's internal Node representation. This makes the special-casing behavior in `__call__` more explicit, and makes it easier to reduce how much special-casing of the first argument Keras does moving forward.

It also allows us to simplify internal node-walking code.
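As an illustration of the simplified walking, here is a minimal sketch (not the code in this change: it mirrors the new `_build_map` traversal in the diff below, relies on the new `Node.is_input` / `Node.keras_inputs` attributes, and the helper name `collect_nodes` is hypothetical):

    from tensorflow.python.util import nest


    def collect_nodes(outputs):
      """Returns all Nodes reachable from `outputs` via _keras_history."""
      collected = []
      seen = set()

      def visit(tensor):
        layer, node_index, _ = tensor._keras_history
        node = layer._inbound_nodes[node_index]
        if node in seen:
          return
        seen.add(node)
        if not node.is_input:
          # Every tensor argument is walked the same way; no first-argument
          # special case is needed anymore.
          for t in node.keras_inputs:
            visit(t)
        collected.append(node)

      for t in nest.flatten(outputs):
        visit(t)
      return collected  # Topologically sorted, inputs first.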

This may come with a subtle behavior change/simplification in the interaction between `training`, learning phase scopes, and model construction. Previously, constructing a network model inside a learning phase scope appeared to cause that model to permanently use that learning phase in some later code, regardless of what value of `training` was passed in.

The code underlying this behavior was fragile and hard to guarantee, and was quite likely unintentional or buggy in many cases. After this change, setting a learning phase scope still affects call-time behavior when `training` isn't explicitly passed in, but building a model inside a learning phase scope no longer forces it to always use that learning phase.

(Passing `training=...` at construction time remains the recommended way of freezing behavior at call time.)
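For concreteness, a sketch of the intended semantics after this change (assuming the TF 2.x-era `tf.keras.backend.learning_phase_scope` API; the Dropout model is illustrative only):

    import numpy as np
    import tensorflow as tf

    x = np.ones((1, 10))

    with tf.keras.backend.learning_phase_scope(1):  # learning phase = training
      inputs = tf.keras.Input((10,))
      outputs = tf.keras.layers.Dropout(0.5)(inputs)
      model = tf.keras.Model(inputs, outputs)

    # After this change, the scope above does not permanently freeze the model
    # into training mode; an explicit `training` argument still wins:
    y_inference = model(x, training=False)  # dropout inactive
    y_training = model(x, training=True)    # dropout active

    # To actually freeze call-time behavior, pass `training` at construction:
    frozen_outputs = tf.keras.layers.Dropout(0.5)(inputs, training=True)
    frozen = tf.keras.Model(inputs, frozen_outputs)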

PiperOrigin-RevId: 308436442
Change-Id: I8cb8922a6e3cd219a1771328dda3003287978b39
Authored by Tomer Kaftan on 2020-04-25 13:43:51 -07:00; committed by TensorFlower Gardener
parent ef72ad0c7a
commit 9485250f44
15 changed files with 751 additions and 599 deletions

--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py

@@ -133,6 +133,7 @@ class _DummyEagerGraph(threading.local):
     # get a different key.
     super(_DummyEagerGraph, self).__init__()
     self.key = _DummyEagerGraph._WeakReferencableClass()
+    self.learning_phase_is_set = False


 _DUMMY_EAGER_GRAPH = _DummyEagerGraph()
@@ -295,6 +296,7 @@ def clear_session():
   _SESSION.session = None
   graph = get_graph()
   with graph.as_default():
+    _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
     _GRAPH_LEARNING_PHASES.clear()
     # Create the learning phase placeholder in graph using the default factory.
     _GRAPH_LEARNING_PHASES.setdefault(graph)
@@ -351,7 +353,7 @@ def learning_phase():


 def global_learning_phase_is_set():
-  return _DUMMY_EAGER_GRAPH.key in _GRAPH_LEARNING_PHASES
+  return _DUMMY_EAGER_GRAPH.learning_phase_is_set


 def _mark_func_graph_as_unsaveable(graph, learning_phase):
@@ -420,6 +422,7 @@ def set_learning_phase(value):
   if context.executing_eagerly():
     # In an eager context, the learning phase values applies to both the eager
     # context and the internal Keras graph.
+    _DUMMY_EAGER_GRAPH.learning_phase_is_set = True
     _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key] = value
   _GRAPH_LEARNING_PHASES[get_graph()] = value
@@ -451,11 +454,14 @@ def learning_phase_scope(value):
         _DUMMY_EAGER_GRAPH.key, None)
   previous_graph_value = _GRAPH_LEARNING_PHASES.get(get_graph(), None)

+  learning_phase_previously_set = _DUMMY_EAGER_GRAPH.learning_phase_is_set
   try:
     set_learning_phase(value)
     yield
   finally:
     # Restore learning phase to initial value.
+    if not learning_phase_previously_set:
+      _DUMMY_EAGER_GRAPH.learning_phase_is_set = False
     with ops.init_scope():
       if context.executing_eagerly():
         if previous_eager_value is not None:

--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD

@@ -495,6 +495,20 @@ tf_py_test(
     ],
 )

+tf_py_test(
+    name = "node_test",
+    size = "medium",
+    srcs = ["node_test.py"],
+    python_version = "PY3",
+    shard_count = 3,
+    deps = [
+        ":base_layer",
+        ":engine",
+        "//tensorflow/python/keras:testing_utils",
+        "//tensorflow/python/keras/utils:layer_utils",
+    ],
+)
+
 tf_py_test(
     name = "base_layer_test",
     size = "medium",

--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import collections
+import copy
 import functools
 import itertools
 import threading
@@ -797,15 +797,22 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       raise RuntimeError(
           'You must call `super().__init__()` in the layer constructor.')

-    # Grab the first positional or keyword argument.
-    if args:
-      inputs = args[0]
-      args = args[1:]
-    elif self._call_fn_args[0] in kwargs:
-      inputs = kwargs.pop(self._call_fn_args[0])
-    else:
-      raise ValueError(
-          'The first argument to `Layer.call` must always be passed.')
+    # 'inputs` (the first arg in the method spec) is special cased in
+    # layer call due to historical reasons.
+    # This special casing currently takes the form of:
+    # - 'inputs' must be explicitly passed. A layer cannot have zero arguments,
+    #   and inputs cannot have been provided via the default value of a kwarg.
+    # - numpy/scalar values in `inputs` get converted to tensors
+    # - implicit masks / mask metadata are only collected from 'inputs`
+    # - Layers are built using shape info from 'inputs' only
+    # - input_spec compatibility is only checked against `inputs`
+    # - checking if a layer has ragged tensor support is only done against
+    #   `inputs`
+    # - mixed precision casting (autocast) is only applied to `inputs`,
+    #   not to any other argument.
+    # - configuring the Functional API SavedModel saving spec for deciding what
+    #   should be serialized during SavedModel saving
+    inputs, args, kwargs = self._split_out_first_arg(args, kwargs)

     call_context = base_layer_utils.call_context()
     input_list = nest.flatten(inputs)
@@ -814,6 +821,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # This is always the case in graph mode. It can also be the case in eager
     # mode when all inputs can be traced back to `keras.Input()` (when building
     # models using the functional API).
+    # TODO(kaftan): make this not special case inputs. Instead
+    # build a functional api model if *any* *arg or **kwarg is symbolic,
+    # even if part of the data structure in that arg is not symbolic.
     build_graph = tf_utils.are_all_symbolic_tensors(input_list)

     # Accept NumPy and scalar inputs by converting to Tensors.
@@ -838,16 +848,17 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         mask_arg_passed_by_framework = True
         kwargs['mask'] = input_masks

-    # If `training` argument was not explicitly passed, propagate `training`
-    # value from this layer's calling layer.
+    # If `training` argument is None or not explicitly passed,
+    # propagate `training` value from this layer's calling layer.
+    training_value = None
     training_arg_passed_by_framework = False
     # Priority 1: `training` was explicitly passed.
     if self._call_arg_was_passed('training', args, kwargs):
       training_value = self._get_call_arg_value('training', args, kwargs)
       if not self._expects_training_arg:
         kwargs.pop('training')
-    else:
-      training_value = None
+
+    if training_value is None:
       # Priority 2: `training` was passed to a parent layer.
       if call_context.training is not None:
         training_value = call_context.training
@@ -867,12 +878,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           training_value = math_ops.cast(training_value, dtypes.bool)
         else:
           training_value = bool(training_value)
-        kwargs['training'] = training_value
+        args, kwargs = self._set_call_arg_value(
+            'training', training_value, args, kwargs)
         training_arg_passed_by_framework = True

     # Only create Keras history if at least one tensor originates from a
     # `keras.Input`. Otherwise this Layer may be being used outside the Keras
     # framework.
+    # TODO(kaftan): make this not special case inputs
     if build_graph and base_layer_utils.needs_keras_history(inputs):
       base_layer_utils.create_keras_history(inputs)
@@ -953,13 +966,16 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           raise ValueError('A layer\'s `call` method should return a '
                            'Tensor or a list of Tensors, not None '
                            '(layer: ' + self.name + ').')
+        # TODO(kaftan): This should be 'any' and check all args
         if base_layer_utils.have_all_keras_metadata(inputs):
           if training_arg_passed_by_framework:
-            kwargs.pop('training')
+            args, kwargs = self._set_call_arg_value(
+                'training', None, args, kwargs, pop_kwarg_if_none=True)
           if mask_arg_passed_by_framework:
             kwargs.pop('mask')
-          inputs, outputs = self._set_connectivity_metadata_(
-              inputs, outputs, args, kwargs)
+          # Node connectivity does not special-case the first argument.
+          outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
+                                                    outputs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks)
         if hasattr(self, '_set_inputs') and not self.inputs:
@@ -2299,70 +2315,45 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     args_dict = dict(zip(call_fn_args, args))
     return args_dict[arg_name]

-  def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
-
-    # If the layer returns tensors from its inputs, unmodified,
-    # we copy them to avoid loss of tensor metadata.
-    output_ls = nest.flatten(outputs)
-    inputs_ls = object_identity.ObjectIdentitySet(nest.flatten(inputs))
-    output_ls_copy = []
-    for x in output_ls:
-      if x in inputs_ls:
-        with backend.name_scope(self.name):
-          x = array_ops.identity(x)
-      output_ls_copy.append(x)
-    outputs = nest.pack_sequence_as(outputs, output_ls_copy)
-
-    # Ignore `inputs` arg.
-    arguments = dict(zip(self._call_fn_args[1:], args))
-    arguments.update(kwargs)
-
-    # Add an inbound node to the layer, so it can keep track of this call.
-    # This updates the layer history of the output tensor(s).
-    self._add_inbound_node(
-        input_tensors=inputs, output_tensors=outputs, arguments=arguments)
-    return inputs, outputs
-
-  def _add_inbound_node(self,
-                        input_tensors,
-                        output_tensors,
-                        arguments=None):
-    """Internal method to create an inbound node for the layer.
-
-    Arguments:
-      input_tensors: list of input tensors.
-      output_tensors: list of output tensors.
-      arguments: dictionary of keyword arguments that were passed to the
-        `call` method of the layer at the call that created the node.
-    """
-    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
-                                        input_tensors)
-    node_indices = nest.map_structure(lambda t: t._keras_history.node_index,
-                                      input_tensors)
-    tensor_indices = nest.map_structure(lambda t: t._keras_history.tensor_index,
-                                        input_tensors)
-
-    # Create node, add it to inbound nodes.
-    node_module.Node(
-        self,
-        inbound_layers=inbound_layers,
-        node_indices=node_indices,
-        tensor_indices=tensor_indices,
-        input_tensors=input_tensors,
-        output_tensors=output_tensors,
-        arguments=arguments)
-
-    # Update tensor history metadata.
-    # The metadata attribute consists of
-    # 1) a layer instance
-    # 2) a node index for the layer
-    # 3) a tensor index for the node.
-    # The allows layer reuse (multiple nodes per layer) and multi-output
-    # or multi-input layers (e.g. a layer can return multiple tensors,
-    # and each can be sent to a different layer).
-    for i, tensor in enumerate(nest.flatten(output_tensors)):
-      tensor._keras_history = KerasHistory(self,
-                                           len(self._inbound_nodes) - 1, i)  # pylint: disable=protected-access
+  def _set_call_arg_value(
+      self, arg_name, new_value, args,
+      kwargs, inputs_in_args=False, pop_kwarg_if_none=False):
+    arg_pos = self._call_fn_arg_positions.get(arg_name, None)
+    if arg_pos is not None:
+      if not inputs_in_args:
+        # Ignore `inputs` arg.
+        arg_pos = arg_pos - 1
+      if len(args) > arg_pos:
+        args = list(args)
+        args[arg_pos] = new_value
+        return args, kwargs
+    if new_value is None and pop_kwarg_if_none:
+      kwargs.pop(arg_name, None)
+    else:
+      kwargs[arg_name] = new_value
+    return args, kwargs
+
+  def _set_connectivity_metadata(self, args, kwargs, outputs):
+    # If the layer returns tensors from its inputs unmodified,
+    # we copy them to avoid loss of KerasHistory metadata.
+    flat_outputs = nest.flatten(outputs)
+    flat_inputs = nest.flatten((args, kwargs))
+    inputs_set = object_identity.ObjectIdentitySet(flat_inputs)
+    outputs_copy = []
+    for x in flat_outputs:
+      if x in inputs_set:
+        with backend.name_scope(self.name):
+          x = array_ops.identity(x)
+      outputs_copy.append(x)
+    outputs = nest.pack_sequence_as(outputs, outputs_copy)
+
+    # Create node, Node wires itself to inbound and outbound layers.
+    # The Node constructor actually updates this layer's self._inbound_nodes,
+    # sets _keras_history on the outputs, and adds itself to the
+    # `_outbound_nodes` of the layers that produced the inputs to this
+    # layer call.
+    node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs)
+    return outputs

   def _get_node_attribute_at_index(self, node_index, attr, attr_name):
     """Private utility to retrieves an attribute (e.g. inputs) from a node.
@@ -2706,6 +2697,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       return all_args[1:]
     return all_args

+  @property
+  @tracking.cached_per_instance
+  def _call_fn_arg_positions(self):
+    call_fn_arg_positions = dict()
+    for pos, arg in enumerate(self._call_fn_args):
+      call_fn_arg_positions[arg] = pos
+    return call_fn_arg_positions
+
   @property
   @tracking.cached_per_instance
   def _call_accepts_kwargs(self):
@@ -2743,6 +2742,21 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         seen_weights.add(w)
     return output

+  def _split_out_first_arg(self, args, kwargs):
+    # Grab the argument corresponding to the first argument in the
+    # layer's `call` method spec. This will either be the first positional
+    # argument, or it will be provided as a keyword argument.
+    if args:
+      inputs = args[0]
+      args = args[1:]
+    elif self._call_fn_args[0] in kwargs:
+      kwargs = copy.copy(kwargs)
+      inputs = kwargs.pop(self._call_fn_args[0])
+    else:
+      raise ValueError(
+          'The first argument to `Layer.call` must always be passed.')
+    return inputs, args, kwargs
+
   # SavedModel properties. Please see keras/saving/saved_model for details.

   @property
@@ -2952,31 +2966,6 @@ class AddMetric(Layer):
     return config


-class KerasHistory(
-    collections.namedtuple('KerasHistory',
-                           ['layer', 'node_index', 'tensor_index'])):
-  """Tracks the Layer call that created a Tensor, for Keras Graph Networks.
-
-  During construction of Keras Graph Networks, this metadata is added to
-  each Tensor produced as the output of a Layer, starting with an
-  `InputLayer`. This allows Keras to track how each Tensor was produced, and
-  this information is later retraced by the `keras.engine.Network` class to
-  reconstruct the Keras Graph Network.
-
-  Attributes:
-    layer: The Layer that produced the Tensor.
-    node_index: The specific call to the Layer that produced this Tensor. Layers
-      can be called multiple times in order to share weights. A new node is
-      created every time a Layer is called.
-    tensor_index: The output index for this Tensor. Always zero if the Layer
-      that produced this Tensor only has one output. Nested structures of
-      Tensors are deterministically assigned an index via `nest.flatten`.
-  """
-  # Added to maintain memory and performance characteristics of `namedtuple`
-  # while subclassing.
-  __slots__ = ()
-
-
 # Avoid breaking users who directly import this symbol from this file.
 # TODO(fchollet): remove this.
 InputSpec = input_spec.InputSpec  # pylint:disable=invalid-name

--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py

@@ -254,8 +254,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
       op_layer = base_layer.TensorFlowOpLayer(
           node_def, constants=constants, name=name)
       created_layers.append(op_layer)
-      op_layer._add_inbound_node(  # pylint: disable=protected-access
-          layer_inputs, op.outputs)
+      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
+          args=(layer_inputs,),
+          kwargs={},
+          outputs=op.outputs)
       processed_ops.update([op])
   return processed_ops, created_layers

--- a/tensorflow/python/keras/engine/base_layer_v1.py
+++ b/tensorflow/python/keras/engine/base_layer_v1.py

@@ -45,7 +45,6 @@ from tensorflow.python.keras import regularizers
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_spec
-from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.mixed_precision.experimental import policy
@@ -698,16 +697,17 @@ class Layer(base_layer.Layer):
         mask_arg_passed_by_framework = True
         kwargs['mask'] = input_masks

-    # If `training` argument was not explicitly passed, propagate `training`
-    # value from this layer's calling layer.
+    # If `training` argument is None or not explicitly passed,
+    # propagate `training` value from this layer's calling layer.
+    training_value = None
     training_arg_passed_by_framework = False
     # Priority 1: `training` was explicitly passed.
     if self._call_arg_was_passed('training', args, kwargs):
       training_value = self._get_call_arg_value('training', args, kwargs)
       if not self._expects_training_arg:
         kwargs.pop('training')
-    else:
-      training_value = None
+
+    if training_value is None:
       # Priority 2: `training` was passed to a parent layer.
       if call_context.training is not None:
         training_value = call_context.training
@@ -727,7 +727,8 @@ class Layer(base_layer.Layer):
           training_value = math_ops.cast(training_value, dtypes.bool)
         else:
           training_value = bool(training_value)
-        kwargs['training'] = training_value
+        args, kwargs = self._set_call_arg_value(
+            'training', training_value, args, kwargs)
         training_arg_passed_by_framework = True

     # Only create Keras history if at least one tensor originates from a
@@ -798,11 +799,12 @@ class Layer(base_layer.Layer):
                            '(layer: ' + self.name + ').')
         if base_layer_utils.have_all_keras_metadata(inputs):
           if training_arg_passed_by_framework:
-            kwargs.pop('training')
+            args, kwargs = self._set_call_arg_value(
+                'training', None, args, kwargs, pop_kwarg_if_none=True)
           if mask_arg_passed_by_framework:
             kwargs.pop('mask')
-          inputs, outputs = self._set_connectivity_metadata_(
-              inputs, outputs, args, kwargs)
+          outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
+                                                    outputs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks)
         if hasattr(self, '_set_inputs') and not self.inputs:
@@ -2005,70 +2007,23 @@ class Layer(base_layer.Layer):
     args_dict = dict(zip(call_fn_args, args))
     return args_dict[arg_name]

-  def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
-
-    # If the layer returns tensors from its inputs, unmodified,
-    # we copy them to avoid loss of tensor metadata.
-    output_ls = nest.flatten(outputs)
-    inputs_ls = object_identity.ObjectIdentitySet(nest.flatten(inputs))
-    output_ls_copy = []
-    for x in output_ls:
-      if x in inputs_ls:
-        with backend.name_scope(self.name):
-          x = array_ops.identity(x)
-      output_ls_copy.append(x)
-    outputs = nest.pack_sequence_as(outputs, output_ls_copy)
-
-    # Ignore `inputs` arg.
-    arguments = dict(zip(self._call_fn_args[1:], args))
-    arguments.update(kwargs)
-
-    # Add an inbound node to the layer, so it can keep track of this call.
-    # This updates the layer history of the output tensor(s).
-    self._add_inbound_node(
-        input_tensors=inputs, output_tensors=outputs, arguments=arguments)
-    return inputs, outputs
-
-  def _add_inbound_node(self,
-                        input_tensors,
-                        output_tensors,
-                        arguments=None):
-    """Internal method to create an inbound node for the layer.
-
-    Arguments:
-      input_tensors: list of input tensors.
-      output_tensors: list of output tensors.
-      arguments: dictionary of keyword arguments that were passed to the
-        `call` method of the layer at the call that created the node.
-    """
-    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
-                                        input_tensors)
-    node_indices = nest.map_structure(lambda t: t._keras_history.node_index,
-                                      input_tensors)
-    tensor_indices = nest.map_structure(lambda t: t._keras_history.tensor_index,
-                                        input_tensors)
-
-    # Create node, add it to inbound nodes.
-    node_module.Node(
-        self,
-        inbound_layers=inbound_layers,
-        node_indices=node_indices,
-        tensor_indices=tensor_indices,
-        input_tensors=input_tensors,
-        output_tensors=output_tensors,
-        arguments=arguments)
-
-    # Update tensor history metadata.
-    # The metadata attribute consists of
-    # 1) a layer instance
-    # 2) a node index for the layer
-    # 3) a tensor index for the node.
-    # The allows layer reuse (multiple nodes per layer) and multi-output
-    # or multi-input layers (e.g. a layer can return multiple tensors,
-    # and each can be sent to a different layer).
-    for i, tensor in enumerate(nest.flatten(output_tensors)):
-      tensor._keras_history = KerasHistory(self,
-                                           len(self._inbound_nodes) - 1, i)  # pylint: disable=protected-access
+  def _set_call_arg_value(
+      self, arg_name, new_value, args,
+      kwargs, inputs_in_args=False, pop_kwarg_if_none=False):
+    arg_pos = self._call_fn_arg_positions.get(arg_name, None)
+    if arg_pos is not None:
+      if not inputs_in_args:
+        # Ignore `inputs` arg.
+        arg_pos = arg_pos - 1
+      if len(args) > arg_pos:
+        args = list(args)
+        args[arg_pos] = new_value
+        return args, kwargs
+    if new_value is None and pop_kwarg_if_none:
+      kwargs.pop(arg_name, None)
+    else:
+      kwargs[arg_name] = new_value
+    return args, kwargs

   def _get_node_attribute_at_index(self, node_index, attr, attr_name):
     """Private utility to retrieves an attribute (e.g. inputs) from a node.
@@ -2395,6 +2350,14 @@ class Layer(base_layer.Layer):
       return all_args[1:]
     return all_args

+  @property
+  @tracking.cached_per_instance
+  def _call_fn_arg_positions(self):
+    call_fn_arg_positions = dict()
+    for pos, arg in enumerate(self._call_fn_args):
+      call_fn_arg_positions[arg] = pos
+    return call_fn_arg_positions
+
   @property
   @tracking.cached_per_instance
   def _call_accepts_kwargs(self):

--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py

@@ -164,17 +164,9 @@ class InputLayer(base_layer.Layer):
       self.is_placeholder = False
       self._batch_input_shape = tuple(input_tensor.shape.as_list())

-    # Create an input node to add to self.outbound_node
-    # and set output_tensors' _keras_history.
-    input_tensor._keras_history = base_layer.KerasHistory(self, 0, 0)
+    # Create an input node.
     input_tensor._keras_mask = None
-    node_module.Node(
-        self,
-        inbound_layers=[],
-        node_indices=[],
-        tensor_indices=[],
-        input_tensors=[input_tensor],
-        output_tensors=[input_tensor])
+    node_module.Node(layer=self, outputs=input_tensor)

   def get_config(self):
     config = {
@@ -294,8 +286,8 @@ def Input(  # pylint: disable=invalid-name
   # Return tensor including `_keras_history`.
   # Note that in this case train_output and test_output are the same pointer.
-  outputs = input_layer._inbound_nodes[0].output_tensors
-  if len(outputs) == 1:
+  outputs = input_layer._inbound_nodes[0].outputs
+  if isinstance(outputs, list) and len(outputs) == 1:
     return outputs[0]
   else:
     return outputs

--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py

@@ -25,13 +25,11 @@ import itertools
 import json
 import os

-import numpy as np
 import six
 from six.moves import zip  # pylint: disable=redefined-builtin

 from tensorflow.python.eager import context
 from tensorflow.python.framework import composite_tensor
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import func_graph
@@ -42,7 +40,6 @@ from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import compile_utils
 from tensorflow.python.keras.engine import input_layer as input_layer_module
-from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.saving import hdf5_format
 from tensorflow.python.keras.saving import save
@@ -307,15 +304,6 @@ class Network(base_layer.Layer):
       self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
       layer._attribute_sentinel.add_parent(self._attribute_sentinel)

-    # Create the node linking internal inputs to internal outputs.
-    node_module.Node(
-        outbound_layer=self,
-        inbound_layers=[],
-        node_indices=[],
-        tensor_indices=[],
-        input_tensors=self._nested_inputs,
-        output_tensors=self._nested_outputs)
-
     # Build self.input_names and self.output_names.
     self._set_output_names()
     self.input_names = []
@@ -337,6 +325,82 @@ class Network(base_layer.Layer):
     self._compute_tensor_usage_count()
     self._set_save_spec(self._nested_inputs)

+  @property
+  def input(self):
+    """Retrieves the input tensor(s) of a layer.
+
+    Only applicable if the layer has exactly one input,
+    i.e. if it is connected to one incoming layer.
+
+    Returns:
+      Input tensor or list of input tensors.
+
+    Raises:
+      RuntimeError: If called in Eager mode.
+      AttributeError: If no inbound nodes are found.
+    """
+    if self._is_graph_network:
+      return self._nested_inputs
+    return super(Network, self).input
+
+  @property
+  def input_shape(self):
+    """Retrieves the input shape(s) of a layer.
+
+    Only applicable if the layer has exactly one input,
+    i.e. if it is connected to one incoming layer, or if all inputs
+    have the same shape.
+
+    Returns:
+      Input shape, as an integer shape tuple
+      (or list of shape tuples, one tuple per input tensor).
+
+    Raises:
+      AttributeError: if the layer has no defined input_shape.
+      RuntimeError: if called in Eager mode.
+    """
+    if self._is_graph_network:
+      return nest.map_structure(backend.int_shape, self.input)
+    return super(Network, self).input_shape
+
+  @property
+  def output(self):
+    """Retrieves the output tensor(s) of a layer.
+
+    Only applicable if the layer has exactly one output,
+    i.e. if it is connected to one incoming layer.
+
+    Returns:
+      Output tensor or list of output tensors.
+
+    Raises:
+      AttributeError: if the layer is connected to more than one incoming
+        layers.
+      RuntimeError: if called in Eager mode.
+    """
+    if self._is_graph_network:
+      return self._nested_outputs
+    return super(Network, self).output
+
+  @property
+  def output_shape(self):
+    """Retrieves the output shape(s) of a layer.
+
+    Only applicable if the layer has one output,
+    or if all outputs have the same shape.
+
+    Returns:
+      Output shape, as an integer shape tuple
+      (or list of shape tuples, one tuple per output tensor).
+
+    Raises:
+      AttributeError: if the layer has no defined output shape.
+      RuntimeError: if called in Eager mode.
+    """
+    if self._is_graph_network:
+      return nest.map_structure(backend.int_shape, self.output)
+    return super(Network, self).output_shape
+
   def _set_output_names(self):
     """Assigns unique names to the Network's outputs.
@@ -700,8 +764,7 @@ class Network(base_layer.Layer):
                                 ' implement a `call` method.')

     return self._run_internal_graph(
-        inputs, training=training, mask=mask,
-        convert_kwargs_to_constants=base_layer_utils.call_context().saving)
+        inputs, training=training, mask=mask)

   def compute_output_shape(self, input_shape):
     if not self._is_graph_network:
@@ -741,20 +804,20 @@ class Network(base_layer.Layer):
       for depth in depth_keys:
         nodes = self._nodes_by_depth[depth]
         for node in nodes:
-          # This is always a single layer, never a list.
-          layer = node.outbound_layer
+          layer = node.layer
           if layer in self._input_layers:
             # We've already covered the input layers
             # a few lines above.
             continue
-          # Potentially redundant list,
-          # same size as node.input_tensors.
+          # Get the input shapes for the first argument of the node
           layer_input_shapes = []
-          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
-            input_layer_key = inbound_layer.name + '_%s_%s' % (node_id,
-                                                               tensor_id)
+          layer_inputs = node.call_args[0]
+          for layer_input in nest.flatten(layer_inputs):
+            kh = layer_input._keras_history
+            input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index,
+                                                          kh.tensor_index)
             layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
-          layer_input_shapes = nest.pack_sequence_as(node.inbound_layers,
+          layer_input_shapes = nest.pack_sequence_as(layer_inputs,
                                                      layer_input_shapes)
           # Layers expect shapes to be tuples for `compute_output_shape`.
           layer_input_shapes = tf_utils.convert_shapes(
@@ -782,8 +845,7 @@ class Network(base_layer.Layer):
     # Return shapes as TensorShapes.
     return output_shapes

-  def _run_internal_graph(self, inputs, training=None, mask=None,
-                          convert_kwargs_to_constants=False):
+  def _run_internal_graph(self, inputs, training=None, mask=None):
     """Computes output tensors for new inputs.

     # Note:
@@ -793,21 +855,10 @@ class Network(base_layer.Layer):
       inputs: Tensor or nested structure of Tensors.
       training: Boolean learning phase.
       mask: (Optional) Tensor or nested structure of Tensors.
-      convert_kwargs_to_constants: Whether to convert Tensor kwargs to
-        constants. This is used when tracing the model call function during
-        saving to ensure that external tensors aren't captured.

     Returns:
       Two lists: output_tensors, output_masks
     """
-    # Note: masking support is relevant mainly for Keras.
-    # It cannot be factored out without having the fully reimplement the network
-    # calling logic on the Keras side. We choose to incorporate it in
-    # Network because 1) it may be useful to fully support in tf.layers in
-    # the future and 2) Keras is a major user of Network. If you don't
-    # use masking, it does not interfere with regular behavior at all and you
-    # can ignore it.
     inputs = self._flatten_to_reference_inputs(inputs)
     if mask is None:
       masks = [None for _ in range(len(inputs))]
@@ -825,58 +876,26 @@ class Network(base_layer.Layer):
     depth_keys = list(self._nodes_by_depth.keys())
     depth_keys.sort(reverse=True)
-    # Ignore the InputLayers when computing the graph.
-    depth_keys = depth_keys[1:]

     for depth in depth_keys:
       nodes = self._nodes_by_depth[depth]
       for node in nodes:
-        # This is always a single layer, never a list.
-        layer = node.outbound_layer
-        if all(
+        if node.is_input:
+          continue  # Input tensors already exist.
+
+        if not all(
             str(id(tensor)) in tensor_dict
-            for tensor in nest.flatten(node.input_tensors)):
-          continue
-
-        # Call layer (reapplying ops to new inputs).
-        computed_tensors = nest.map_structure(
-            lambda t: tensor_dict[str(id(t))].pop(), node.input_tensors)
-
-        # Ensure `training` arg propagation if applicable.
-        kwargs = copy.copy(node.arguments) if node.arguments else {}
-        if convert_kwargs_to_constants:
-          kwargs = _map_tensors_to_constants(kwargs)
-        argspec = self._layer_call_argspecs[layer].args
-        if 'training' in argspec:
-          if 'training' not in kwargs or kwargs['training'] is None:
-            kwargs['training'] = training
-          elif (type(kwargs['training']) is ops.Tensor and  # pylint: disable=unidiomatic-typecheck
-                any([
-                    kwargs['training'] is x
-                    for x in backend._GRAPH_LEARNING_PHASES.values()
-                ])):
-            kwargs['training'] = training  # Materialize placeholder.
-
-        # Map Keras tensors in kwargs to their computed value.
-        def _map_tensor_if_from_keras_layer(t):
-          if (isinstance(t,
-                         (ops.Tensor, composite_tensor.CompositeTensor)) and
-              hasattr(t, '_keras_history')):
-            t_id = str(id(t))
-            return tensor_dict[t_id].pop()
-          return t
-
-        kwargs = nest.map_structure(_map_tensor_if_from_keras_layer, kwargs)
-
-        # Compute outputs.
-        output_tensors = layer(computed_tensors, **kwargs)
-
-        # Update tensor_dict.
-        for x, y in zip(
-            nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
-          x_id = str(id(x))
-          tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
+            for tensor in nest.flatten(node.keras_inputs)):
+          continue  # Node is not computable, try skipping.
+
+        layer = node.layer
+        args, kwargs = node.map_arguments(tensor_dict)
+        outputs = layer(*args, **kwargs)
+
+        # Update tensor_dict.
+        for x, y in zip(nest.flatten(node.outputs), nest.flatten(outputs)):
+          x_id = str(id(x))
+          tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]

     output_tensors = []
     output_shapes = []
@@ -1367,7 +1386,7 @@ class Network(base_layer.Layer):
       # pylint: disable=protected-access
       layer = x._keras_history.layer
       if len(layer._inbound_nodes) > 1 or (
-          layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
+          layer._inbound_nodes and not layer._inbound_nodes[0].is_input):
         cls_name = self.__class__.__name__
         logging.warning(cls_name + ' inputs must come from '
                         '`tf.keras.Input` (thus holding past layer metadata), '
@@ -1437,7 +1456,7 @@ class Network(base_layer.Layer):
     def _get_min_depth(node):
       """Gets the minimum depth at which node can be computed."""
       min_depth = 0
-      for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True):
+      for layer, node_id, _, _ in node.iterate_inbound():
         inbound_node = layer._inbound_nodes[node_id]
         if inbound_node in node_to_depth:
           min_depth = min(min_depth, node_to_depth[inbound_node])
@@ -1465,8 +1484,8 @@ class Network(base_layer.Layer):
         if depth is None:  # Defer until inbound nodes are processed.
           unprocessed_nodes.append(node)
           continue
-        node_key = _make_node_key(node.outbound_layer.name,
-                                  node.outbound_layer._inbound_nodes.index(node))
+        node_key = _make_node_key(node.layer.name,
+                                  node.layer._inbound_nodes.index(node))
         if node_key not in self._network_nodes:
           node_to_depth[node] = depth
           self._network_nodes.add(node_key)
@@ -1506,21 +1525,13 @@ class Network(base_layer.Layer):
     for depth in depth_keys:
       for node in self._nodes_by_depth[depth]:
         input_tensors = {
-            str(id(tensor)) for tensor in nest.flatten(node.input_tensors)
+            str(id(tensor)) for tensor in nest.flatten(node.keras_inputs)
         }
         if input_tensors.issubset(available_tensors):
-          kwargs = copy.copy(node.arguments) if node.arguments else {}
-
-          for tensor in nest.flatten(kwargs):
-            if (isinstance(tensor,
-                           (ops.Tensor, composite_tensor.CompositeTensor)) and
-                hasattr(tensor, '_keras_history')):
-              tensor_usage_count[str(id(tensor))] += 1
-
-          for tensor in nest.flatten(node.input_tensors):
+          for tensor in nest.flatten(node.keras_inputs):
             tensor_usage_count[str(id(tensor))] += 1

-          for output_tensor in nest.flatten(node.output_tensors):
+          for output_tensor in nest.flatten(node.outputs):
             available_tensors.add(str(id(output_tensor)))

     for tensor in self.outputs:
@@ -1631,97 +1642,35 @@ def _map_graph_network(inputs, outputs):
   Raises:
     ValueError: In case the network is not valid (e.g. disconnected graph).
   """
-  # Network_nodes: set of nodes included in the graph of layers
-  # (not all nodes included in the layers are relevant to the current graph).
-  network_nodes = set()  # ids of all nodes relevant to the Network
+  # "depth" is number of layers between output Node and the Node.
+  # Nodes are ordered from inputs -> outputs.
+  nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
+  network_nodes = {
+      _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
+      for node in nodes_in_decreasing_depth
+  }

   nodes_depths = {}  # dict {node: depth value}
   layers_depths = {}  # dict {layer: depth value}
-  layer_indices = {}  # dict {layer: index in traversal}
-  nodes_in_decreasing_depth = []
-
-  def build_map(tensor,
-                finished_nodes,
-                nodes_in_progress,
-                layer,
-                node_index,
-                tensor_index):
-    """Builds a map of the graph of layers.
-
-    This recursively updates the map `layer_indices`,
-    the list `nodes_in_decreasing_depth` and the set `network_nodes`.
-
-    Arguments:
-      tensor: Some tensor in a graph.
-      finished_nodes: Set of nodes whose subgraphs have been traversed
-        completely. Useful to prevent duplicated work.
-      nodes_in_progress: Set of nodes that are currently active on the
-        recursion stack. Useful to detect cycles.
-      layer: Layer from which `tensor` comes from. If not provided,
-        will be obtained from `tensor._keras_history`.
-      node_index: Node index from which `tensor` comes from.
-      tensor_index: Tensor_index from which `tensor` comes from.
-
-    Raises:
-      ValueError: if a cycle is detected.
-    """
-    node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access
-
-    # Prevent cycles.
-    if node in nodes_in_progress:
-      raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
-                       layer.name + '" is part of a cycle.')
-
-    # Don't repeat work for shared subgraphs
-    if node in finished_nodes:
-      return
-
-    node_key = _make_node_key(layer.name, node_index)
-    # Update network_nodes.
-    network_nodes.add(node_key)
-
-    # Store the traversal order for layer sorting.
-    if layer not in layer_indices:
-      layer_indices[layer] = len(layer_indices)
-
-    nodes_in_progress.add(node)
-
-    # Propagate to all previous tensors connected to this node.
-    for layer, node_index, tensor_index, tensor in node.iterate_inbound(
-        include_arguments=True):
-      build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
-                tensor_index)
-
-    finished_nodes.add(node)
-    nodes_in_progress.remove(node)
-    nodes_in_decreasing_depth.append(node)
-
-  finished_nodes = set()
-  nodes_in_progress = set()
-  for x in outputs:
-    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
-    build_map(x, finished_nodes, nodes_in_progress,
-              layer=layer,
-              node_index=node_index,
-              tensor_index=tensor_index)

   for node in reversed(nodes_in_decreasing_depth):
     # If the depth is not set, the node has no outbound nodes (depth 0).
     depth = nodes_depths.setdefault(node, 0)

     # Update the depth of the corresponding layer
-    previous_depth = layers_depths.get(node.outbound_layer, 0)
+    previous_depth = layers_depths.get(node.layer, 0)
     # If we've seen this layer before at a higher depth,
     # we should use that depth instead of the node depth.
     # This is necessary for shared layers that have inputs at different
     # depth levels in the graph.
     depth = max(depth, previous_depth)
-    layers_depths[node.outbound_layer] = depth
+    layers_depths[node.layer] = depth
     nodes_depths[node] = depth

     # Update the depth of inbound nodes.
     # The "depth" of a node is the max of the depths
     # of all nodes it is connected to + 1.
-    for node_dep in node._get_all_node_dependencies():
+    for node_dep in node.parent_nodes:
       previous_depth = nodes_depths.get(node_dep, 0)
       nodes_depths[node_dep] = max(depth + 1, previous_depth)
@@ -1773,9 +1722,9 @@ def _map_graph_network(inputs, outputs):
   layers_with_complete_input = []  # To provide a better error msg.
   for depth in depth_keys:
     for node in nodes_by_depth[depth]:
-      layer = node.outbound_layer
-      if layer:
-        for x in nest.flatten(node.input_tensors):
+      layer = node.layer
+      if layer and not node.is_input:
+        for x in nest.flatten(node.keras_inputs):
           if id(x) not in computable_tensors:
             raise ValueError('Graph disconnected: '
                              'cannot obtain value for tensor ' + str(x) +
@@ -1783,7 +1732,7 @@ def _map_graph_network(inputs, outputs):
                              'The following previous layers '
                              'were accessed without issue: ' +
                              str(layers_with_complete_input))
-        for x in nest.flatten(node.output_tensors):
+        for x in nest.flatten(node.outputs):
           computable_tensors.add(id(x))
         layers_with_complete_input.append(layer.name)
@@ -1798,6 +1747,68 @@ def _map_graph_network(inputs, outputs):
   return network_nodes, nodes_by_depth, layers, layers_by_depth


+def _build_map(outputs):
+  """This method topologically sorts nodes in order from inputs to outputs.
+
+  It uses a depth-first search to topologically sort nodes that appear in the
+  _keras_history connectivity metadata of `outputs`.
+
+  Args:
+    outputs: the output tensors whose _keras_history metadata should be walked.
+      This may be an arbitrary nested structure.
+
+  Returns:
+    A tuple like (ordered_nodes, layer_to_first_traversal_index)
+    ordered_nodes: list of nodes appearing in the keras history, topologically
+      sorted from original inputs to the `outputs`.
+      (If outputs have different sets of ancestors, the inputs to one output
+      may appear after a different output).
+    layer_to_first_traversal_index:
+      A dict mapping layer to the traversal index in the DFS where it is
+      seen. Note: if a layer is shared by several nodes, the dict will only
+      store the index corresponding to the *first* time the layer seen.
+  """
+  finished_nodes = set()
+  nodes_in_progress = set()
+  nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
+  layer_indices = {}  # layer -> in traversal order.
+  for output in nest.flatten(outputs):
+    _build_map_helper(output, finished_nodes, nodes_in_progress,
+                      nodes_in_decreasing_depth, layer_indices)
+  return nodes_in_decreasing_depth, layer_indices
+
+
+def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
+                      nodes_in_decreasing_depth, layer_indices):
+  """Recursive helper for `_build_map`."""
+  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
+  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access
+
+  # Don't repeat work for shared subgraphs
+  if node in finished_nodes:
+    return
+
+  # Prevent cycles.
+  if node in nodes_in_progress:
+    raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name +
+                     '" is part of a cycle.')
+
+  # Store the traversal order for layer sorting.
+  if layer not in layer_indices:
+    layer_indices[layer] = len(layer_indices)
+
+  # Propagate to all previous tensors connected to this node.
+  nodes_in_progress.add(node)
+  if not node.is_input:
+    for tensor in node.keras_inputs:
+      _build_map_helper(tensor, finished_nodes, nodes_in_progress,
+                        nodes_in_decreasing_depth, layer_indices)
+
+  finished_nodes.add(node)
+  nodes_in_progress.remove(node)
+  nodes_in_decreasing_depth.append(node)
+
+
 def _map_subgraph_network(inputs, outputs):
   """Returns the nodes and layers in the topology from `inputs` to `outputs`.
@@ -1820,36 +1831,6 @@ def _should_skip_first_node(layer):
   return issubclass(layer.__class__, Network) and layer._is_graph_network


-def _serialize_tensors(kwargs):
-  """Serializes Tensors passed to `call`."""
-
-  def _serialize_keras_tensor(t):
-    """Serializes a single Tensor passed to `call`."""
-    if hasattr(t, '_keras_history'):
-      kh = t._keras_history
-      return [kh.layer.name, kh.node_index, kh.tensor_index]
-
-    if isinstance(t, np.ndarray):
-      return t.tolist()
-
-    if isinstance(t, ops.Tensor):
-      return backend.get_value(t).tolist()
-
-    return t
-
-  return nest.map_structure(_serialize_keras_tensor, kwargs)
-
-
-def _map_tensors_to_constants(kwargs):
-
-  def _map_to_constants(t):
-    if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
-      return constant_op.constant(backend.get_value(t))
-    return t
-
-  return nest.map_structure(_map_to_constants, kwargs)
-
-
 def _deserialize_keras_tensors(kwargs, layer_map):
   """Deserializes Keras Tensors passed to `call`.."""
@@ -1863,7 +1844,7 @@ def _deserialize_keras_tensors(kwargs, layer_map):
       layer = layer_map[layer_name]
       node = layer._inbound_nodes[node_index]
-      return nest.flatten(node.output_tensors)[tensor_index]
+      return nest.flatten(node.outputs)[tensor_index]
     return t

   kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
@@ -1961,7 +1942,7 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
         return
       inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
       input_tensors.append(
-          nest.flatten(inbound_node.output_tensors)[inbound_tensor_index])
+          nest.flatten(inbound_node.outputs)[inbound_tensor_index])
     input_tensors = nest.pack_sequence_as(node_data, input_tensors)
     # Call layer on its inputs, thus creating the node
     # and building the layer if needed.
@@ -2077,37 +2058,11 @@ def get_network_config(network, serialize_layer_fn=None):
     filtered_inbound_nodes = []
     for original_node_index, node in enumerate(layer._inbound_nodes):
       node_key = _make_node_key(layer.name, original_node_index)
-      if node_key in network._network_nodes:
+      if node_key in network._network_nodes and not node.is_input:
         # The node is relevant to the model:
         # add to filtered_inbound_nodes.
-        if node.arguments:
-          kwargs = _serialize_tensors(node.arguments)
-          try:
-            json.dumps(kwargs)
-          except TypeError:
-            logging.warning(
-                'Layer ' + layer.name +
-                ' was passed non-serializable keyword arguments: ' +
-                str(node.arguments) + '. They will not be included '
-                'in the serialized model (and thus will be missing '
-                'at deserialization time).')
-            kwargs = {}
-        else:
-          kwargs = {}
-        if node.inbound_layers:
-          node_data = []
-          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
-            node_key = _make_node_key(inbound_layer.name, node_id)
-            new_node_index = node_conversion_map.get(node_key, 0)
-            node_data.append(
-                tf_utils.ListWrapper(
-                    [inbound_layer.name, new_node_index, tensor_id, kwargs]))
-          node_data = nest.pack_sequence_as(node.input_tensors, node_data)
-          if not nest.is_sequence(node_data):
-            node_data = [node_data]
-          # Convert ListWrapper to list for backwards compatible configs.
-          node_data = tf_utils.convert_inner_node_data(node_data)
-          filtered_inbound_nodes.append(node_data)
+        node_data = node.serialize(_make_node_key, node_conversion_map)
+        filtered_inbound_nodes.append(node_data)

     layer_config = serialize_layer_fn(layer)
     layer_config['name'] = layer.name

--- a/tensorflow/python/keras/engine/network_test.py
+++ b/tensorflow/python/keras/engine/network_test.py

@@ -1785,6 +1785,7 @@ class AttrTrackingLayer(base_layer.Layer):
     return super(AttrTrackingLayer, self).dynamic


+@combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class CacheCorrectnessTest(keras_parameterized.TestCase):

   def layer_and_network_test(self):
@@ -1919,8 +1920,12 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
     class MyLayer(base_layer.Layer):

       def call(self, x, training=None):
-        self.training = training
-        return x
+        if training is None:
+          return x * -1.0
+        elif training:
+          return x
+        else:
+          return x * 0.0

     my_layer = MyLayer()
     x = np.ones((1, 10))
@@ -1929,9 +1934,8 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
     outputs = my_layer(inputs, training=True)
     network = network_lib.Network(inputs, outputs)

-    network(x, training=False)
     # Hard-coded value passed during construction is respected.
-    self.assertTrue(my_layer.training)
+    self.assertAllEqual(network(x, training=False), x)

     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=False)
@@ -1939,19 +1943,16 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
     network(x, training=True)
     # Hard-coded value passed during construction is respected.
-    self.assertFalse(my_layer.training)
+    self.assertAllEqual(network(x, training=True), x * 0.0)

     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=None)
     network = network_lib.Network(inputs, outputs)

-    network(x, training=True)
     # `None` value passed during construction is overridden.
-    self.assertTrue(my_layer.training)
+    self.assertAllEqual(network(x, training=True), x)
-    network(x, training=False)
     # `None` value passed during construction is overridden.
-    self.assertFalse(my_layer.training)
+    self.assertAllEqual(network(x, training=False), x * 0.0)

 if __name__ == '__main__':
   test.main()

--- a/tensorflow/python/keras/engine/node.py
+++ b/tensorflow/python/keras/engine/node.py

@ -18,10 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import copy
+import json
+import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+from tensorflow.python.util import serialization
class Node(object):
@ -33,163 +41,227 @@ class Node(object):
a node is added to `layer._outbound_nodes`.
Arguments:
-outbound_layer: the layer that takes
-`input_tensors` and turns them into `output_tensors`
-(the node gets created when the `call`
-method of the layer was called).
-inbound_layers: a list of layers, the same length as `input_tensors`,
-the layers from where `input_tensors` originate.
-node_indices: a list of integers, the same length as `inbound_layers`.
-`node_indices[i]` is the origin node of `input_tensors[i]`
-(necessary since each inbound layer might have several nodes,
-e.g. if the layer is being shared with a different data stream).
-tensor_indices: a list of integers,
-the same length as `inbound_layers`.
-`tensor_indices[i]` is the index of `input_tensors[i]` within the
-output of the inbound layer
-(necessary since each inbound layer might
-have multiple tensor outputs, with each one being
-independently manipulable).
-input_tensors: list of input tensors.
-output_tensors: list of output tensors.
-arguments: dictionary of keyword arguments that were passed to the
-`call` method of the layer at the call that created the node.
-`node_indices` and `tensor_indices` are basically fine-grained coordinates
-describing the origin of the `input_tensors`.
-A node from layer A to layer B is added to:
-- A._outbound_nodes
-- B._inbound_nodes
+layer: The Layer for the Layer.__call__ this node represents.
+call_args: The positional arguments the Layer was called with.
+call_kwargs: The keyword arguments the Layer was called with.
+outputs: The outputs of the Layer.__call__
"""
def __init__(self,
-outbound_layer,
-inbound_layers,
-node_indices,
-tensor_indices,
-input_tensors,
-output_tensors,
-arguments=None):
-# Layer instance (NOT a sequence)
-if isinstance(outbound_layer, (list, tuple, dict)):
-raise ValueError('`outbound_layer` should be a layer instance, '
-'not a list, tuple, or, dict.')
-# These arguments are user-provided. Copy them here so that future
-# user modifications do not affect the node's metadata.
-input_tensors = nest.map_structure(lambda t: t, input_tensors)
-output_tensors = nest.map_structure(lambda t: t, output_tensors)
-arguments = nest.map_structure(lambda t: t, arguments)
-# this is the layer that takes a nested structure of input tensors
-# and turns them into a nested structure of output tensors.
-# the current node will be added to
-# the inbound_nodes of outbound_layer.
-self.outbound_layer = outbound_layer
-# The following 3 properties describe where
-# the input tensors come from: which layers,
-# and for each layer, which node and which
-# tensor output of each node.
-# Nested structure of layer instances.
-self.inbound_layers = inbound_layers
-# Nested structure of integers, 1:1 mapping with inbound_layers.
-self.node_indices = node_indices
-# Nested of integers, 1:1 mapping with inbound_layers.
-self.tensor_indices = tensor_indices
-# Following 2 properties:
-# tensor inputs and outputs of outbound_layer.
-# Nested structure of tensors. 1:1 mapping with inbound_layers.
-self.input_tensors = input_tensors
-# Nested structure of tensors, created by outbound_layer.call().
-self.output_tensors = output_tensors
-# Following 2 properties: input and output shapes.
-# Nested structure of shape tuples, shapes of input_tensors.
-self.input_shapes = nest.map_structure(backend.int_shape, input_tensors)
-# Nested structure of shape tuples, shapes of output_tensors.
-self.output_shapes = nest.map_structure(backend.int_shape, output_tensors)
-# Optional keyword arguments to layer's `call`.
-self.arguments = arguments
-# Create Keras History for any Keras Tensors in `arguments`.
-tensor_arguments = [
-t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
-]
-for tensor_argument in tensor_arguments:
-if base_layer_utils.needs_keras_history(
-tensor_argument, ignore_call_context=True):
-base_layer_utils.create_keras_history(tensor_argument)
-# Add nodes to all layers involved.
-for layer in nest.flatten(inbound_layers):
-if layer is not None:
-# For compatibility with external Keras, we use the deprecated
-# accessor here.
-layer.outbound_nodes.append(self)
-# For compatibility with external Keras, we use the deprecated
-# accessor here.
-outbound_layer.inbound_nodes.append(self)
-def iterate_inbound(self, include_arguments=False):
-"""Returns a list of tuples representing the inbound data.
-Arguments:
-include_arguments: Whether to also iterate over any Keras Tensors
-passed as args, kwargs.
-Returns:
-List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
-"""
-inputs_inbound = list(
-zip(
-nest.flatten(self.inbound_layers),
-nest.flatten(self.node_indices),
-nest.flatten(self.tensor_indices),
-nest.flatten(self.input_tensors)))
-if include_arguments:
-keras_tensor_arguments = [
-kt for kt in nest.flatten(self.arguments)
-if hasattr(kt, '_keras_history')
-]
-def _get_inbound(keras_tensor):
-kh = keras_tensor._keras_history
-return kh.layer, kh.node_index, kh.tensor_index, keras_tensor
-arguments_inbound = nest.map_structure(_get_inbound,
-keras_tensor_arguments)
-return inputs_inbound + arguments_inbound
-else:
-return inputs_inbound
-def _get_all_node_dependencies(self):
-"""Returns all of the nodes this node immediately depends on."""
+layer,
+call_args=None,
+call_kwargs=None,
+outputs=None):
+call_args = [] if call_args is None else call_args
+call_kwargs = {} if call_kwargs is None else call_kwargs
+outputs = [] if outputs is None else outputs
+self.layer = layer
+self.is_input = not call_args and not call_kwargs
+# These arguments are user-provided. Copy the structures here so that
+# future user modifications do not affect the node's metadata.
+# We copy using map_structure rather than python's shallow or deep copy,
+# because the args can be data structures (so shallow copy is
+# insufficient), but individual values might not support copy.copy
+# or be too expensive to deep copy.
+call_args = nest.map_structure(lambda t: t, call_args)
+call_kwargs = nest.map_structure(lambda t: t, call_kwargs)
+self.outputs = nest.map_structure(lambda t: t, outputs)
+self.call_args = call_args
+self.call_kwargs = call_kwargs
+# Cached for performance.
+self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs))
+# Create TensorFlowOpLayers if needed.
+for obj in self._flat_arguments:
+if (isinstance(obj, ops.Tensor) and
+base_layer_utils.needs_keras_history(obj, ignore_call_context=True)):
+base_layer_utils.create_keras_history(obj)
+self._keras_inputs = []
+self._keras_inputs_ids_and_indices = []
+for i, ele in enumerate(self._flat_arguments):
+if is_keras_tensor(ele):
+self._keras_inputs.append(ele)
+kt_id = str(id(ele))
+kt_index = i
+self._keras_inputs_ids_and_indices.append((kt_id, kt_index))
+# Wire up Node to Layers.
+self.layer._inbound_nodes.append(self)
+for kt in self.keras_inputs:
+inbound_layer = kt._keras_history.layer
+if inbound_layer is not None:  # `None` for `Input` tensors.
+inbound_layer._outbound_nodes.append(self)
+# Set metadata on outputs.
+node_index = len(self.layer._inbound_nodes) - 1
+for i, tensor in enumerate(nest.flatten(outputs)):
+tensor._keras_history = KerasHistory(
+layer=layer, node_index=node_index, tensor_index=i)
+@property
+def keras_inputs(self):
+"""Tensors input to this node that can be traced back to a `keras.Input`."""
+return self._keras_inputs
+@property
+def parent_nodes(self):
+"""Returns all the `Node`s whose output this node immediately depends on."""
node_deps = []
-for layer, node_index, _, _ in self.iterate_inbound():
-node_deps.append(layer._inbound_nodes[node_index])
-for arg in nest.flatten(self.arguments):
-if isinstance(arg, ops.Tensor) and hasattr(arg, '_keras_history'):
-kh = arg._keras_history
-node_deps.append(kh.layer._inbound_nodes[kh.node_index])
+for kt in self.keras_inputs:
+layer = kt._keras_history.layer
+node_index = kt._keras_history.node_index
+if layer is not None:  # `None` for `Input` tensors.
+node_deps.append(layer._inbound_nodes[node_index])
return node_deps
-def get_config(self):
-inbound_names = nest.map_structure(
-lambda layer: layer.name if layer else None, self.inbound_layers)
-return {
-'outbound_layer': self.outbound_layer.name,
-'inbound_layers': inbound_names,
-'node_indices': self.node_indices,
-'tensor_indices': self.tensor_indices
-}
+def iterate_inbound(self):
+"""Yields tuples representing the data inbound from other nodes.
+Yields:
+tuples like: (inbound_layer, node_index, tensor_index, tensor).
+"""
+for kt in self.keras_inputs:
+keras_history = kt._keras_history
+layer = keras_history.layer
+node_index = keras_history.node_index
+tensor_index = keras_history.tensor_index
+yield layer, node_index, tensor_index, kt
+def map_arguments(self, tensor_dict):
+"""Maps Keras Tensors to computed Tensors using `tensor_dict`."""
+flat_arguments = copy.copy(self._flat_arguments)
+for kt_id, kt_index in self._keras_inputs_ids_and_indices:
+flat_arguments[kt_index] = tensor_dict[kt_id].pop()
+args, kwargs = nest.pack_sequence_as(
+(self.call_args, self.call_kwargs), flat_arguments)
+return args, kwargs
+def serialize(self, make_node_key, node_conversion_map):
+"""Serializes `Node` for Functional API's `get_config`."""
+# Serialization still special-cases first argument.
+args, kwargs = self.call_args, self.call_kwargs
+inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs)
+# Treat everything other than first argument as a kwarg.
+arguments = dict(zip(self.layer._call_fn_args[1:], args))
+arguments.update(kwargs)
+kwargs = arguments
+kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
+try:
+json.dumps(kwargs, default=serialization.get_json_type)
+except TypeError:
+kwarg_types = nest.map_structure(type, kwargs)
+logging.warning('Layer ' + self.layer.name +
+' was passed non-JSON-serializable arguments. ' +
+'Arguments had types: ' +
+str(kwarg_types) + '. They will not be included '
+'in the serialized model (and thus will be missing '
+'at deserialization time).')
+kwargs = {}
+# `kwargs` is added to each Tensor in the first arg. This should be
+# changed in a future version of the serialization format.
+def serialize_first_arg_tensor(t):
+kh = t._keras_history
+node_index = kh.node_index
+node_key = make_node_key(kh.layer.name, node_index)
+new_node_index = node_conversion_map.get(node_key, 0)
+data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
+return tf_utils.ListWrapper(data)
+data = nest.map_structure(serialize_first_arg_tensor, inputs)
+if not nest.is_sequence(data):
+data = [data]
+data = tf_utils.convert_inner_node_data(data)
+return data
+#############################################################
+# Properties for Backwards compatibility.
+# These only check the first input argument
+# As nodes are internal, they may be removed in the future.
+#############################################################
+@property
+def input_tensors(self):
+if self.is_input:
+return [self.outputs]  # Used in `Layer.input`.
+return self.call_args[0]
+@property
+def output_tensors(self):
+if self.is_input:
+return [self.outputs]  # Used in `Layer.input`.
+return self.outputs
+@property
+def input_shapes(self):
+input_shapes = nest.map_structure(backend.int_shape, self.input_tensors)
+if len(input_shapes) == 1 and not self.is_input:
+return input_shapes[0]
+return input_shapes
+@property
+def output_shapes(self):
+return nest.map_structure(backend.int_shape, self.output_tensors)
+@property
+def outbound_layer(self):
+return self.layer
+@property
+def inbound_layers(self):
+if self.is_input:
+return []
+inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
+self.call_args[0])
+return inbound_layers
+class KerasHistory(
+collections.namedtuple('KerasHistory',
+['layer', 'node_index', 'tensor_index'])):
+"""Tracks the Layer call that created a Tensor, for Keras Graph Networks.
+During construction of Keras Graph Networks, this metadata is added to
+each Tensor produced as the output of a Layer, starting with an
+`InputLayer`. This allows Keras to track how each Tensor was produced, and
+this information is later retraced by the `keras.engine.Network` class to
+reconstruct the Keras Graph Network.
+Attributes:
+layer: The Layer that produced the Tensor.
+node_index: The specific call to the Layer that produced this Tensor. Layers
+can be called multiple times in order to share weights. A new node is
+created every time a Layer is called.
+tensor_index: The output index for this Tensor. Always zero if the Layer
+that produced this Tensor only has one output. Nested structures of
+Tensors are deterministically assigned an index via `nest.flatten`.
+"""
+# Added to maintain memory and performance characteristics of `namedtuple`
+# while subclassing.
+__slots__ = ()
+def is_keras_tensor(obj):
+return hasattr(obj, '_keras_history')
+def _serialize_keras_tensor(t):
+"""Serializes a single Tensor passed to `call`."""
+if hasattr(t, '_keras_history'):
+kh = t._keras_history
+return [kh.layer.name, kh.node_index, kh.tensor_index]
+if isinstance(t, np.ndarray):
+return t.tolist()
+if isinstance(t, ops.Tensor):
+return backend.get_value(t).tolist()
+return t
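For orientation, the metadata recorded above can be walked after building a small functional model. A minimal sketch, assuming TF 2.x; `_keras_history` and `_inbound_nodes` are internal attributes and may change:

import tensorflow as tf

inputs = tf.keras.Input((8,))
hidden = tf.keras.layers.Dense(4, name='hidden')(inputs)
outputs = tf.keras.layers.Dense(1, name='out')(hidden)
model = tf.keras.Model(inputs, outputs)

# Every output tensor carries a KerasHistory namedtuple.
layer, node_index, tensor_index = outputs._keras_history
print(layer.name, node_index, tensor_index)  # out 0 0

# Each Layer.__call__ appended a Node; walk back toward the inputs.
node = layer._inbound_nodes[node_index]
for parent, parent_node_index, _, _ in node.iterate_inbound():
  print(parent.name, parent_node_index)  # hidden 0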

View File

@ -0,0 +1,160 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for layer graphs construction & handling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.platform import test
class DummyTensor(object):
def __init__(self, shape=None):
self.shape = shape
class DummyLayer(base_layer.Layer):
pass
class NetworkConstructionTest(keras_parameterized.TestCase):
def test_chained_node_construction(self):
# test basics
a = DummyTensor(shape=(None, 32))
b = DummyTensor(shape=(None, 32))
a_layer = DummyLayer()
node = node_module.Node(a_layer, outputs=a)
self.assertEqual(node.outbound_layer, a_layer)
self.assertTrue(node.is_input)
self.assertListEqual(node.inbound_layers, [])
self.assertListEqual(node.input_tensors, [a])
self.assertListEqual(node.input_shapes, [(None, 32)])
self.assertListEqual(node.output_tensors, [a])
self.assertListEqual(node.output_shapes, [(None, 32)])
b_layer = DummyLayer()
node_module.Node(b_layer, outputs=b)
dense = DummyLayer()
a_2 = DummyTensor()
node_a = node_module.Node(layer=dense, call_args=(a,), outputs=a_2)
b_2 = DummyTensor()
node_b = node_module.Node(layer=dense, call_args=(b,), outputs=b_2)
# test the node attributes
self.assertFalse(node_a.is_input)
self.assertFalse(node_b.is_input)
self.assertEqual(node_a.call_args, (a,))
self.assertEqual(node_a.call_kwargs, {})
self.assertEqual(node_a.outputs, a_2)
# Test the layer wiring
self.assertLen(dense._inbound_nodes, 2)
self.assertLen(dense._outbound_nodes, 0)
self.assertEqual(dense._inbound_nodes, [node_a, node_b])
self.assertEqual(dense._inbound_nodes[0].inbound_layers, a_layer)
self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense)
self.assertEqual(dense._inbound_nodes[1].inbound_layers, b_layer)
self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense)
self.assertIs(dense._inbound_nodes[0].input_tensors, a)
self.assertIs(dense._inbound_nodes[1].input_tensors, b)
def test_multi_input_node(self):
# test multi-input layer
a = DummyTensor()
b = DummyTensor()
dense = DummyLayer()
a_2 = DummyTensor()
node_module.Node(layer=dense, call_args=(a,), outputs=a_2)
b_2 = DummyTensor()
node_module.Node(layer=dense, call_args=(b,), outputs=b_2)
concat_layer = DummyLayer()
merged = DummyTensor()
node_module.Node(layer=concat_layer, call_args=([a_2, b_2],),
outputs=merged)
merge_layer, merge_node_index, merge_tensor_index = merged._keras_history
self.assertEqual(merge_node_index, 0)
self.assertEqual(merge_tensor_index, 0)
self.assertLen(merge_layer._inbound_nodes, 1)
self.assertLen(merge_layer._outbound_nodes, 0)
self.assertLen(merge_layer._inbound_nodes[0].input_tensors, 2)
self.assertEqual(merge_layer._inbound_nodes[0].input_tensors, [a_2, b_2])
self.assertLen(merge_layer._inbound_nodes[0].inbound_layers, 2)
def test_arg_and_kwarg_mix(self):
input_layer = DummyLayer()
input_layer_2 = DummyLayer()
a = DummyTensor()
node_a = node_module.Node(layer=input_layer, outputs=a)
b = DummyTensor()
node_b = node_module.Node(layer=input_layer_2, outputs=b)
arg_2 = DummyTensor()
arg_3 = DummyTensor()
node_c = node_module.Node(layer=input_layer, outputs=arg_3)
kwarg_x = DummyTensor()
kwarg_y = DummyTensor()
node_d = node_module.Node(layer=input_layer, outputs=kwarg_y)
merge_layer = DummyLayer()
merged = DummyTensor()
node = node_module.Node(layer=merge_layer,
call_args=([a, b], arg_2, arg_3),
call_kwargs={'x': kwarg_x, 'y': kwarg_y},
outputs=merged)
merge_layer, merge_node_index, merge_tensor_index = merged._keras_history
# Check the saved call args/kwargs
self.assertEqual(([a, b], arg_2, arg_3), node.call_args)
self.assertEqual({'x': kwarg_x, 'y': kwarg_y}, node.call_kwargs)
# Only the inputs that were produced by input nodes should appear in
# keras_tensors
self.assertEqual({a, b, arg_3, kwarg_y}, set(node.keras_inputs))
self.assertEqual(set(node.parent_nodes), {node_a, node_b, node_c, node_d})
# Check the layer wirings
self.assertEqual(merge_node_index, 0)
self.assertEqual(merge_tensor_index, 0)
self.assertLen(merge_layer._inbound_nodes, 1)
self.assertLen(merge_layer._outbound_nodes, 0)
self.assertLen(input_layer._outbound_nodes, 3)
self.assertLen(input_layer_2._outbound_nodes, 1)
# The 'backwards compatibility' attributes should only check the
# first call argument
self.assertLen(merge_layer._inbound_nodes[0].input_tensors, 2)
self.assertEqual(merge_layer._inbound_nodes[0].input_tensors, [a, b])
self.assertLen(merge_layer._inbound_nodes[0].inbound_layers, 2)
if __name__ == '__main__':
test.main()
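One piece of the new API not exercised above is `map_arguments`. A minimal sketch in the same style, reusing the `DummyTensor`/`DummyLayer` stand-ins from this test file (the dict values are stacks keyed by `str(id(tensor))`, hence the list):

a = DummyTensor()
input_layer = DummyLayer()
node_module.Node(layer=input_layer, outputs=a)  # records a KerasHistory on `a`

dense = DummyLayer()
out = DummyTensor()
node = node_module.Node(layer=dense, call_args=(a, 2.0), outputs=out)

# At execution time, recorded Keras tensors are swapped for computed values.
computed = DummyTensor()
args, kwargs = node.map_arguments({str(id(a)): [computed]})
assert args == (computed, 2.0)
assert kwargs == {}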

View File

@ -215,7 +215,7 @@ class Sequential(training.Model):
set_inputs = True
if set_inputs:
-outputs = nest.flatten(layer._inbound_nodes[-1].output_tensors)
+outputs = nest.flatten(layer._inbound_nodes[-1].outputs)
if len(outputs) != 1:
raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
self.outputs = outputs

View File

@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import contextlib
from absl.testing import parameterized
import numpy as np
@ -302,11 +300,11 @@ class CorrectnessTest(keras_parameterized.TestCase):
self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)
@parameterized.named_parameters([
-('_None', contextlib.contextmanager(lambda: iter([None])), 0., 4.),
-('_0', lambda: keras.backend.learning_phase_scope(0), 4., 4.),
-('_1', lambda: keras.backend.learning_phase_scope(1), 0., 0.),
+('_None', None, 0., 4.),
+('_False', False, 4., 4.),
+('_True', True, 0., 0.),
])
-def test_nested_model_learning_phase(self, nested_scope_fn,
+def test_nested_model_learning_phase(self, training,
expected_training_loss,
expected_validation_loss):
"""Tests that learning phase is correctly set in an intermediate layer."""
@ -326,18 +324,17 @@ class CorrectnessTest(keras_parameterized.TestCase):
return keras.Model(inputs, outputs)
def _regularize_model(unregularized_model):
-inputs = keras.Input(unregularized_model.inputs[0].shape[1:])
-with nested_scope_fn():
-logits = unregularized_model(inputs)
-outputs = keras.activations.softmax(logits)
-model = keras.Model(inputs, outputs)
# Regularize the most recent activations of a post-dropout layer.
sample_activations = unregularized_model.get_layer(
index=-2).get_output_at(-1)
regularization_loss = keras.backend.mean(sample_activations)
-model.add_loss(regularization_loss)
-model.add_metric(
+unregularized_model.add_loss(regularization_loss)
+unregularized_model.add_metric(
regularization_loss, aggregation='mean', name='regularization_loss')
+inputs = keras.Input(unregularized_model.inputs[0].shape[1:])
+logits = unregularized_model(inputs, training=training)
+outputs = keras.activations.softmax(logits)
+model = keras.Model(inputs, outputs)
return model
# Make and compile models.
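As the rewritten test reflects, the supported way to freeze train/inference behavior is now to pass `training=...` when the inner model is called during construction, rather than to build under a learning phase scope. A public-API sketch (TF 2.x, names illustrative):

import tensorflow as tf

base = tf.keras.Sequential([tf.keras.layers.Dropout(0.5)])

inputs = tf.keras.Input((4,))
# Dropout is frozen off for `deterministic`, whatever `training` value the
# outer model is later called with.
outputs = base(inputs, training=False)
deterministic = tf.keras.Model(inputs, outputs)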

View File

@ -112,11 +112,12 @@ def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
# then call node.inbound_layer on them.
if all(
tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
-computed_tensors = nest.map_structure(lambda t: tensor_map[t],
-node.input_tensors)
# Call layer.
-kwargs = node.arguments or {}
-output_tensors = layer(computed_tensors, **kwargs)
+args = nest.map_structure(lambda t: tensor_map.get(t, t),
+node.call_args)
+kwargs = nest.map_structure(lambda t: tensor_map.get(t, t),
+node.call_kwargs)
+output_tensors = layer(*args, **kwargs)
# Thread-safe way to keep track of what node was created.
first_output_tensor = nest.flatten(output_tensors)[0]
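The `tensor_map.get(t, t)` idiom above substitutes only the arguments that have cloned counterparts and passes everything else (literals, `None`) through unchanged. A standalone sketch of the idiom, using strings in place of symbolic tensors because eager tensors are not hashable dict keys:

import tensorflow as tf

call_args = ('kt_a', ['kt_b', 0.5])              # stand-ins for recorded args
tensor_map = {'kt_a': 'new_a', 'kt_b': 'new_b'}  # original -> cloned

new_args = tf.nest.map_structure(lambda t: tensor_map.get(t, t), call_args)
print(new_args)  # ('new_a', ['new_b', 0.5])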

View File

@ -547,7 +547,7 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
else:
return inputs
-t = array_ops.sequence_mask(1)
+t = self.evaluate(array_ops.sequence_mask(1))
inputs = keras.layers.Input(shape=(3))
model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t))

View File

@ -55,7 +55,7 @@ def get_source_inputs(tensor, layer=None, node_index=None):
return [tensor]
else:
node = layer._inbound_nodes[node_index]
-if not node.inbound_layers:
+if node.is_input:
# Reached an Input layer, stop recursion.
return nest.flatten(node.input_tensors)
else:
@ -140,7 +140,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
nodes = []
for v in nodes_by_depth:
if (len(v) > 1) or (len(v) == 1 and
-len(nest.flatten(v[0].inbound_layers)) > 1):
+len(nest.flatten(v[0].keras_inputs)) > 1):
# if the model has multiple nodes
# or if the nodes have multiple inbound_layers
# the model is no longer sequential
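This check decides whether `summary()` can use the simple sequential layout. A quick way to see the other branch with the public API (TF 2.x): a multi-input graph gives a node multiple inbound layers, so the summary switches to the table with a 'Connected to' column.

import tensorflow as tf

a = tf.keras.Input((4,), name='a')
b = tf.keras.Input((4,), name='b')
out = tf.keras.layers.Add()([a, b])
tf.keras.Model([a, b], out).summary()  # prints the 'Connected to' column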