1317 lines
52 KiB
Python
1317 lines
52 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
# pylint: disable=protected-access
|
|
"""A `Network` is way to compose layers: the topological form of a `Model`.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import copy
|
|
import itertools
|
|
import warnings
|
|
|
|
from six.moves import zip # pylint: disable=redefined-builtin
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import composite_tensor
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.keras import backend
|
|
from tensorflow.python.keras.engine import base_layer
|
|
from tensorflow.python.keras.engine import base_layer_utils
|
|
from tensorflow.python.keras.engine import input_layer as input_layer_module
|
|
from tensorflow.python.keras.engine import keras_tensor
|
|
from tensorflow.python.keras.engine import node as node_module
|
|
from tensorflow.python.keras.engine import training as training_lib
|
|
from tensorflow.python.keras.engine import training_utils
|
|
from tensorflow.python.keras.saving.saved_model import network_serialization
|
|
from tensorflow.python.keras.utils import generic_utils
|
|
from tensorflow.python.keras.utils import tf_utils
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.training.tracking import base as trackable
|
|
from tensorflow.python.util import nest
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
|
|
# pylint: disable=g-classes-have-attributes
|
|
class Functional(training_lib.Model):
|
|
"""A `Functional` model is a `Model` defined as a directed graph of layers.
|
|
|
|
Three types of `Model` exist: subclassed `Model`, `Functional` model,
|
|
and `Sequential` (a special case of `Functional`).
|
|
In general, more Keras features are supported with `Functional`
|
|
than with subclassed `Model`s, specifically:
|
|
|
|
- Model cloning (`keras.models.clone`)
|
|
  - Serialization (`model.get_config()/from_config`, `model.to_json()/to_yaml()`)
|
|
- Whole-model saving (`model.save()`)
|
|
|
|
A `Functional` model can be instantiated by passing two arguments to
|
|
`__init__`. The first argument is the `keras.Input` Tensors that represent
|
|
the inputs to the model. The second argument specifies the output
|
|
tensors that represent the outputs of this model. Both arguments can be a
|
|
nested structure of tensors.
|
|
|
|
Example:
|
|
|
|
```
|
|
inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
|
|
t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
|
|
  outputs = keras.layers.Add()([t, inputs['x2']])
|
|
model = keras.Model(inputs, outputs)
|
|
```
|
|
|
|
A `Functional` model constructed using the Functional API can also include raw
|
|
TensorFlow functions, with the exception of functions that create Variables
|
|
or assign ops.
|
|
|
|
Example:
|
|
|
|
```
|
|
inputs = keras.Input(shape=(10,))
|
|
x = keras.layers.Dense(1)(inputs)
|
|
outputs = tf.nn.relu(x)
|
|
model = keras.Model(inputs, outputs)
|
|
```
|
|
|
|
Arguments:
|
|
inputs: List of input tensors (must be created via `tf.keras.Input()`).
|
|
outputs: List of outputs tensors.
|
|
name: String, optional. Name of the model.
|
|
trainable: Boolean, whether the model's variables should be trainable.
|
|
"""
|
|
|
|
  # See tf.Module for the usage of this property.
  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
  # flatten the key since it is trying to convert Trackable/Layer to a string.
  # The caches listed here are derived state, not model state, so they are
  # excluded from module flattening along with the base class's exclusions.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_layer_call_argspecs', '_compiled_trainable_state',
       '_output_mask_cache', '_output_tensor_cache', '_output_shape_cache'),
      training_lib.Model._TF_MODULE_IGNORED_PROPERTIES
  ))
|
|
|
|
@trackable.no_automatic_dependency_tracking
|
|
def __init__(self, inputs=None, outputs=None, name=None, trainable=True):
|
|
# generic_utils.validate_kwargs(
|
|
# kwargs, {'name', 'trainable'},
|
|
# 'Functional models may only specify `name` and `trainable` keyword '
|
|
# 'arguments during initialization. Got an unexpected argument:')
|
|
super(Functional, self).__init__(name=name, trainable=trainable)
|
|
self._init_graph_network(inputs, outputs)
|
|
|
|
  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs):
    """Builds all graph-network state from `inputs` and `outputs`.

    This method is needed for Sequential to reinitialize graph network when
    layer is added or removed.

    Arguments:
      inputs: Tensor or nested structure of input tensors.
      outputs: Tensor or nested structure of output tensors.
    """
    self._is_graph_network = True

    # Normalize and set self.inputs, self.outputs.
    # A single-element list is unwrapped so `self._nested_inputs/_outputs`
    # hold a bare tensor in the single-input/output case.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_inputs = inputs
    self._nested_outputs = outputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    # Models constructed with a single Tensor or list of Tensors can
    # be called with a dict, where the keys of the dict are the names
    # of the `Input` objects. Extra keys are ignored with warning.
    self._enable_dict_to_input_mapping = (
        not nest.is_sequence(self._nested_inputs) or
        (isinstance(self._nested_inputs, (list, tuple, dict)) and
         not any(nest.is_sequence(t) for t in self._nested_inputs)))

    # If any output lacks `_keras_history`, raw TF ops were used; wrap them
    # into Keras layers so the graph can be traversed.
    if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
      base_layer_utils.create_keras_history(self._nested_outputs)

    self._validate_graph_inputs_and_outputs()

    # A Network does not create weights of its own, thus it is already
    # built.
    self.built = True
    self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs)
    self._compute_output_and_mask_jointly = True
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True
    self._expects_mask_arg = True
    # A graph network does not autocast inputs, as its layers will cast them
    # instead.
    self._autocast = False

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set on input tensors,
    # we compute the output tensors, output masks and output shapes in one pass,
    # then cache them here. When any of these outputs is queried later, we
    # retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, _ = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._layers = layers
    # Cache each layer's call signature once for later argument mapping.
    self._layer_call_argspecs = {}
    for layer in self._layers:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    # Build self.input_names and self.output_names.
    self._set_output_names()
    self.input_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for layer in self._input_layers:
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        # Use batch_input_shape here because non-eager composite tensors may not
        # have a shape attribute that's meaningful (sparse, for instance, has
        # a tensor that's non-constant and needs to be fed). This means that
        # input layers that create placeholders will need to have the
        # batch_input_shape attr to allow for input shape validation.
        self._feed_input_shapes.append(layer._batch_input_shape)
        self._feed_inputs.append(layer.input)

    self._compute_tensor_usage_count()
    self._set_save_spec(self._nested_inputs)
    tf_utils.assert_no_legacy_layers(self.layers)
|
|
|
|
@property
|
|
def input(self):
|
|
"""Retrieves the input tensor(s) of a layer.
|
|
|
|
Only applicable if the layer has exactly one input,
|
|
i.e. if it is connected to one incoming layer.
|
|
|
|
Returns:
|
|
Input tensor or list of input tensors.
|
|
|
|
Raises:
|
|
RuntimeError: If called in Eager mode.
|
|
AttributeError: If no inbound nodes are found.
|
|
"""
|
|
return self._nested_inputs
|
|
|
|
@property
|
|
def input_shape(self):
|
|
"""Retrieves the input shape(s) of a layer.
|
|
|
|
Only applicable if the layer has exactly one input,
|
|
i.e. if it is connected to one incoming layer, or if all inputs
|
|
have the same shape.
|
|
|
|
Returns:
|
|
Input shape, as an integer shape tuple
|
|
(or list of shape tuples, one tuple per input tensor).
|
|
|
|
Raises:
|
|
AttributeError: if the layer has no defined input_shape.
|
|
RuntimeError: if called in Eager mode.
|
|
"""
|
|
return nest.map_structure(backend.int_shape, self.input)
|
|
|
|
@property
|
|
def output(self):
|
|
"""Retrieves the output tensor(s) of a layer.
|
|
|
|
Only applicable if the layer has exactly one output,
|
|
i.e. if it is connected to one incoming layer.
|
|
|
|
Returns:
|
|
Output tensor or list of output tensors.
|
|
|
|
Raises:
|
|
AttributeError: if the layer is connected to more than one incoming
|
|
layers.
|
|
RuntimeError: if called in Eager mode.
|
|
"""
|
|
return self._nested_outputs
|
|
|
|
@property
|
|
def output_shape(self):
|
|
"""Retrieves the output shape(s) of a layer.
|
|
|
|
Only applicable if the layer has one output,
|
|
or if all outputs have the same shape.
|
|
|
|
Returns:
|
|
Output shape, as an integer shape tuple
|
|
(or list of shape tuples, one tuple per output tensor).
|
|
|
|
Raises:
|
|
AttributeError: if the layer has no defined output shape.
|
|
RuntimeError: if called in Eager mode.
|
|
"""
|
|
return nest.map_structure(backend.int_shape, self.output)
|
|
|
|
def _set_output_names(self):
|
|
"""Assigns unique names to the Network's outputs.
|
|
|
|
Output layers with multiple output tensors would otherwise lead to duplicate
|
|
names in self.output_names.
|
|
"""
|
|
uniquified = []
|
|
output_names = set()
|
|
prefix_count = {}
|
|
for layer in self._output_layers:
|
|
proposal = layer.name
|
|
while proposal in output_names:
|
|
existing_count = prefix_count.get(layer.name, 1)
|
|
proposal = '{}_{}'.format(layer.name, existing_count)
|
|
prefix_count[layer.name] = existing_count + 1
|
|
output_names.add(proposal)
|
|
uniquified.append(proposal)
|
|
self.output_names = uniquified
|
|
|
|
@property
|
|
def _layer_checkpoint_dependencies(self):
|
|
"""Dictionary of layer dependencies to be included in the checkpoint."""
|
|
weight_layer_index = 0
|
|
|
|
dependencies = collections.OrderedDict()
|
|
for layer_index, layer in enumerate(self.layers):
|
|
try:
|
|
if layer.weights:
|
|
# Keep a separate index for layers which have weights. This allows
|
|
# users to insert Layers without weights anywhere in the network
|
|
# without breaking checkpoints.
|
|
dependencies['layer_with_weights-%d' % weight_layer_index] = layer
|
|
weight_layer_index += 1
|
|
except ValueError:
|
|
# The layer might have weights, but may not be built yet. We just treat
|
|
# it as layer without weight.
|
|
pass
|
|
|
|
# Even if it doesn't have weights, we should still track everything in
|
|
# case it has/will have Trackable dependencies.
|
|
dependencies['layer-%d' % layer_index] = layer
|
|
return dependencies
|
|
|
|
@property
|
|
def _checkpoint_dependencies(self):
|
|
dependencies = [
|
|
trackable.TrackableReference(name=name, ref=layer)
|
|
for name, layer in self._layer_checkpoint_dependencies.items()]
|
|
dependencies.extend(super(Functional, self)._checkpoint_dependencies)
|
|
return dependencies
|
|
|
|
def _lookup_dependency(self, name):
|
|
layer_dependencies = self._layer_checkpoint_dependencies
|
|
if name in layer_dependencies:
|
|
return layer_dependencies[name]
|
|
return super(Functional, self)._lookup_dependency(name)
|
|
|
|
def _handle_deferred_layer_dependencies(self, layers):
|
|
"""Handles layer checkpoint dependencies that are added after init."""
|
|
layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
|
|
layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
|
|
for layer in layers:
|
|
if layer in layer_to_name:
|
|
self._handle_deferred_dependencies(name=layer_to_name[layer],
|
|
trackable=layer)
|
|
|
|
@property
|
|
def _should_compute_mask(self):
|
|
return True
|
|
|
|
def compute_mask(self, inputs, mask):
|
|
# TODO(omalleyt): b/123540974 This function is not really safe to call
|
|
# by itself because it will duplicate any updates and losses in graph
|
|
# mode by `call`ing the Layers again.
|
|
output_tensors = self._run_internal_graph(inputs, mask=mask)
|
|
return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
|
|
output_tensors)
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
"""Calls the model on new inputs.
|
|
|
|
In this case `call` just reapplies
|
|
all ops in the graph to the new inputs
|
|
(e.g. build a new computational graph from the provided inputs).
|
|
|
|
Arguments:
|
|
inputs: A tensor or list of tensors.
|
|
training: Boolean or boolean scalar tensor, indicating whether to run
|
|
the `Network` in training mode or inference mode.
|
|
mask: A mask or list of masks. A mask can be
|
|
either a tensor or None (no mask).
|
|
|
|
Returns:
|
|
A tensor if there is a single output, or
|
|
a list of tensors if there are more than one outputs.
|
|
"""
|
|
return self._run_internal_graph(
|
|
inputs, training=training, mask=mask)
|
|
|
|
def compute_output_shape(self, input_shape):
|
|
# Convert any shapes in tuple format to TensorShapes.
|
|
input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
|
|
|
|
if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
|
|
raise ValueError('Invalid input_shape argument ' + str(input_shape) +
|
|
': model has ' + str(len(self._input_layers)) +
|
|
' tensor inputs.')
|
|
|
|
# Use the tuple of TensorShape as the cache key, since tuple is hashable
|
|
# and can be used as hash key.
|
|
try:
|
|
cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
|
|
if cache_key in self._output_shape_cache:
|
|
# Cache hit. Return shapes as TensorShapes.
|
|
return self._output_shape_cache[cache_key]
|
|
except ValueError:
|
|
# In case there are unknown TensorShape, eg for sparse tensor input,
|
|
# We skip the caching since the shape is unknown.
|
|
pass
|
|
|
|
layers_to_output_shapes = {}
|
|
for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
|
|
# It's an input layer: then `compute_output_shape` is identity,
|
|
# and there is only one node and one tensor..
|
|
shape_key = layer.name + '_0_0'
|
|
layers_to_output_shapes[shape_key] = shape
|
|
|
|
depth_keys = list(self._nodes_by_depth.keys())
|
|
depth_keys.sort(reverse=True)
|
|
# Iterate over nodes, by depth level.
|
|
if len(depth_keys) > 1:
|
|
for depth in depth_keys:
|
|
nodes = self._nodes_by_depth[depth]
|
|
for node in nodes:
|
|
layer = node.layer
|
|
if layer in self._input_layers:
|
|
# We've already covered the input layers
|
|
# a few lines above.
|
|
continue
|
|
# Get the input shapes for the first argument of the node
|
|
layer_input_shapes = []
|
|
layer_inputs = node.call_args[0]
|
|
for layer_input in nest.flatten(layer_inputs):
|
|
kh = layer_input._keras_history
|
|
input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index,
|
|
kh.tensor_index)
|
|
layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
|
|
layer_input_shapes = nest.pack_sequence_as(layer_inputs,
|
|
layer_input_shapes)
|
|
# Layers expect shapes to be tuples for `compute_output_shape`.
|
|
layer_input_shapes = tf_utils.convert_shapes(
|
|
layer_input_shapes, to_tuples=True)
|
|
layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
|
|
# Convert back to TensorShapes.
|
|
layer_output_shapes = tf_utils.convert_shapes(
|
|
layer_output_shapes, to_tuples=False)
|
|
|
|
node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access
|
|
for j, shape in enumerate(nest.flatten(layer_output_shapes)):
|
|
shape_key = layer.name + '_%s_%s' % (node_index, j)
|
|
layers_to_output_shapes[shape_key] = shape
|
|
|
|
# Read final output shapes from layers_to_output_shapes.
|
|
output_shapes = []
|
|
for i in range(len(self._output_layers)):
|
|
layer, node_index, tensor_index = self._output_coordinates[i]
|
|
shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
|
|
output_shapes.append(layers_to_output_shapes[shape_key])
|
|
output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
|
|
# Store in cache.
|
|
self._output_shape_cache[cache_key] = output_shapes
|
|
|
|
# Return shapes as TensorShapes.
|
|
return output_shapes
|
|
|
|
  def _run_internal_graph(self, inputs, training=None, mask=None):
    """Computes output tensors for new inputs.

    # Note:
        - Can be run on non-Keras tensors.

    Arguments:
        inputs: Tensor or nested structure of Tensors.
        training: Boolean learning phase.
        mask: (Optional) Tensor or nested structure of Tensors.

    Returns:
        output_tensors
    """
    inputs = self._flatten_to_reference_inputs(inputs)
    if mask is None:
      masks = [None] * len(inputs)
    else:
      masks = self._flatten_to_reference_inputs(mask)
    # NOTE(review): the loop variable `mask` shadows the `mask` parameter;
    # harmless here since the parameter is not read again afterwards.
    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
    tensor_dict = {}
    tensor_usage_count = self._tensor_usage_count
    for x, y in zip(self.inputs, inputs):
      y = self._conform_to_reference_input(y, ref_input=x)
      x_id = str(id(x))
      # Each tensor is stored once per expected use, so `.pop()` releases it
      # after its final consumer runs (memory optimization in eager mode).
      tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    nodes_by_depth = self._nodes_by_depth
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Execute nodes from the deepest depth (inputs) toward depth 0 (outputs).
    for depth in depth_keys:
      nodes = nodes_by_depth[depth]
      for node in nodes:
        if node.is_input:
          continue  # Input tensors already exist.

        if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
          continue  # Node is not computable, try skipping.

        args, kwargs = node.map_arguments(tensor_dict)
        outputs = node.layer(*args, **kwargs)

        # Update tensor_dict.
        for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
          tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    output_tensors = []
    for x in self.outputs:
      x_id = str(id(x))
      assert x_id in tensor_dict, 'Could not compute output ' + str(x)
      output_tensors.append(tensor_dict[x_id].pop())

    return nest.pack_sequence_as(self._nested_outputs, output_tensors)
|
|
|
|
  def _flatten_to_reference_inputs(self, tensors):
    """Maps `tensors` to their respective `keras.Input`.

    Arguments:
      tensors: Tensor, dict, or nested structure of input data.

    Returns:
      A flat list of tensors ordered to match `self.inputs`.
    """
    if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
      ref_inputs = self._nested_inputs
      if not nest.is_sequence(ref_inputs):
        ref_inputs = [self._nested_inputs]
      if isinstance(ref_inputs, dict):
        # In the case that the graph is constructed with dict input tensors,
        # We will use the original dict key to map with the keys in the input
        # data. Note that the model.inputs is using nest.flatten to process the
        # input tensors, which means the dict input tensors are ordered by their
        # keys.
        ref_input_names = sorted(ref_inputs.keys())
      else:
        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]

      # Raise an warning if there are more input data comparing to input tensor
      if len(tensors) > len(ref_input_names):
        warnings.warn(
            'Input dict contained keys {} which did not match any model input. '
            'They will be ignored by the model.'.format(
                [n for n in tensors.keys() if n not in ref_input_names])
        )

      try:
        # Flatten in the order `Input`s were passed during Model construction.
        return [tensors[n] for n in ref_input_names]
      except KeyError:
        # TODO(b/151582614)
        return nest.flatten(tensors)

    # Otherwise both self.inputs and tensors will already be in same order.
    return nest.flatten(tensors)
|
|
|
|
  def _conform_to_reference_input(self, tensor, ref_input):
    """Set shape and dtype based on `keras.Input`s.

    Arguments:
      tensor: A tensor supplied by the caller.
      ref_input: The `keras.Input` the tensor corresponds to.

    Returns:
      `tensor`, possibly squeezed/expanded on the last axis and cast to the
      reference input's dtype.
    """
    if isinstance(tensor, ops.Tensor):
      # Allow (None,) and (None, 1) Tensors to be passed interchangably. Use the
      # shape specified by the `keras.Input`.
      t_shape = tensor.shape
      t_rank = t_shape.rank
      ref_shape = ref_input.shape
      ref_rank = ref_shape.rank
      keras_history = getattr(tensor, '_keras_history', None)
      if t_rank is not None and ref_rank is not None:
        # Should squeeze last dimension.
        # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
        if (t_rank == ref_rank + 1 and t_shape[-1] == 1):
          tensor = array_ops.squeeze_v2(tensor, axis=-1)
        # Should expand last_dimension.
        # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
        elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
          tensor = array_ops.expand_dims_v2(tensor, axis=-1)
      # squeeze/expand return a new Tensor, which drops `_keras_history`.
      if keras_history is not None:  # Restore keras history.
        tensor._keras_history = keras_history

      # Add shape hints to Tensors that may have None shape dims but have shapes
      # defined by the `keras.Input` (not applicable in eager mode).
      if not context.executing_eagerly():
        try:
          tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
        except ValueError:
          logging.warning(
              'Model was constructed with shape {} for input {}, but it was '
              'called on an input with incompatible shape {}.'.format(
                  ref_input.shape, ref_input, tensor.shape))

      # Dtype casting.
      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
    elif isinstance(tensor, composite_tensor.CompositeTensor):
      # Dtype casting.
      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)

    return tensor
|
|
|
|
def get_config(self):
|
|
return copy.deepcopy(get_network_config(self))
|
|
|
|
@classmethod
|
|
def from_config(cls, config, custom_objects=None):
|
|
"""Instantiates a Model from its config (output of `get_config()`).
|
|
|
|
Arguments:
|
|
config: Model config dictionary.
|
|
custom_objects: Optional dictionary mapping names
|
|
(strings) to custom classes or functions to be
|
|
considered during deserialization.
|
|
|
|
Returns:
|
|
A model instance.
|
|
|
|
Raises:
|
|
ValueError: In case of improperly formatted config dict.
|
|
"""
|
|
input_tensors, output_tensors, created_layers = reconstruct_from_config(
|
|
config, custom_objects)
|
|
model = cls(inputs=input_tensors, outputs=output_tensors,
|
|
name=config.get('name'))
|
|
connect_ancillary_layers(model, created_layers)
|
|
return model
|
|
|
|
  def _validate_graph_inputs_and_outputs(self):
    """Validates the inputs and outputs of a Graph Network.

    Raises:
      ValueError: If inputs are duplicated, an input/output lacks
        `_keras_history`, or Input layers declare incompatible batch sizes.
    """
    # Check for redundancy in inputs.
    if len({id(i) for i in self.inputs}) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    for x in self.inputs:
      # Check that x has appropriate `_keras_history` metadata.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Input tensors to a ' + cls_name + ' ' +
                         'must come from `tf.keras.Input`. '
                         'Received: ' + str(x) +
                         ' (missing previous layer metadata).')
      # Check that x is an input tensor.
      # pylint: disable=protected-access
      layer = x._keras_history.layer
      # Non-Input origins are only warned about, not rejected.
      if len(layer._inbound_nodes) > 1 or (
          layer._inbound_nodes and not layer._inbound_nodes[0].is_input):
        cls_name = self.__class__.__name__
        logging.warning(cls_name + ' inputs must come from '
                        '`tf.keras.Input` (thus holding past layer metadata), '
                        'they cannot be the output of '
                        'a previous non-Input layer. '
                        'Here, a tensor specified as '
                        'input to "' + self.name + '" was not an Input tensor, '
                        'it was generated by layer ' + layer.name + '.\n'
                        'Note that input tensors are '
                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
                        'The tensor that caused the issue was: ' + str(x.name))

    # Check compatibility of batch sizes of Input Layers.
    input_batch_sizes = [
        training_utils.get_static_batch_size(x._keras_history.layer)
        for x in self.inputs
    ]
    consistent_batch_size = None
    for batch_size in input_batch_sizes:
      if batch_size is not None:
        if (consistent_batch_size is not None and
            batch_size != consistent_batch_size):
          raise ValueError('The specified batch sizes of the Input Layers'
                           ' are incompatible. Found batch sizes: {}'.format(
                               input_batch_sizes))
        consistent_batch_size = batch_size

    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors to a ' + cls_name + ' must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))
|
|
|
|
  def _insert_layers(self, layers, relevant_nodes=None):
    """Inserts Layers into the Network after Network creation.

    This is only valid for Keras Graph Networks. Layers added via this function
    will be included in the `call` computation and `get_config` of this Network.
    They will not be added to the Network's outputs.

    Arguments:
      layers: Arbitrary nested structure of Layers. Layers must be reachable
        from one or more of the `keras.Input` Tensors that correspond to this
        Network's inputs.
      relevant_nodes: Nodes from the Layers that should be considered part of
        this Network. If `None`, all Nodes will be considered part of this
        Network.

    Raises:
      ValueError: If the layers depend on `Input`s not found in this Model.
    """
    layers = nest.flatten(layers)
    tf_utils.assert_no_legacy_layers(layers)
    node_to_depth = {}
    for depth, nodes in self._nodes_by_depth.items():
      node_to_depth.update({node: depth for node in nodes})
    # The nodes of these Layers that are relevant to this Network. If not
    # provided, assume all Nodes are relevant
    if not relevant_nodes:
      relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers])
    network_nodes = set(relevant_nodes + list(node_to_depth.keys()))

    def _get_min_depth(node):
      """Gets the minimum depth at which node can be computed."""
      min_depth = 0
      for layer, node_id, _, _ in node.iterate_inbound():
        inbound_node = layer._inbound_nodes[node_id]
        if inbound_node in node_to_depth:
          min_depth = min(min_depth, node_to_depth[inbound_node])
        elif inbound_node not in network_nodes:
          # Inbound node is outside this Network; it imposes no constraint.
          continue
        else:
          # Previous relevant nodes haven't been processed yet.
          return None
      # New node is one shallower than its shallowest input.
      return min_depth - 1

    # Insert nodes into `_nodes_by_depth` and other node attrs.
    unprocessed_nodes = copy.copy(relevant_nodes)
    i = 0
    while unprocessed_nodes:
      i += 1
      # Do a sanity check. This can occur if `Input`s from outside this Model
      # are being relied on.
      if i > 10000:
        raise ValueError('Layers could not be added due to missing '
                         'dependencies.')

      node = unprocessed_nodes.pop(0)
      depth = _get_min_depth(node)
      if depth is None:  # Defer until inbound nodes are processed.
        unprocessed_nodes.append(node)
        continue
      node_key = _make_node_key(node.layer.name,
                                node.layer._inbound_nodes.index(node))
      if node_key not in self._network_nodes:
        node_to_depth[node] = depth
        self._network_nodes.add(node_key)
        self._nodes_by_depth[depth].append(node)

    # Insert layers and update other layer attrs.
    layer_set = set(self._layers)
    deferred_layers = []
    for layer in layers:
      if layer not in layer_set:
        self._layers.append(layer)
        deferred_layers.append(layer)
        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
        layer_set.add(layer)
    self._handle_deferred_layer_dependencies(deferred_layers)

    self._compute_tensor_usage_count()
|
|
|
|
  def _compute_tensor_usage_count(self):
    """Compute the #. of tensor usages for all the output tensors of layers.

    The computed tensor usage count is saved as `self._tensor_usage_count`. This
    is later used for saving memory in eager computation by releasing
    no-longer-needed tensors as early as possible.
    """
    tensor_usage_count = collections.Counter()
    available_tensors = set(str(id(tensor)) for tensor in self.inputs)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Skip the first (largest) depth level. NOTE(review): this presumably
    # holds the input nodes, whose tensors were seeded above — confirm.
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      for node in self._nodes_by_depth[depth]:
        input_tensors = {
            str(id(tensor)) for tensor in nest.flatten(node.keras_inputs)
        }
        # Only count usages for nodes whose inputs are all reachable.
        if input_tensors.issubset(available_tensors):
          for tensor in nest.flatten(node.keras_inputs):
            tensor_usage_count[str(id(tensor))] += 1

          for output_tensor in nest.flatten(node.outputs):
            available_tensors.add(str(id(output_tensor)))

    # Model outputs are consumed once more when returned to the caller.
    for tensor in self.outputs:
      tensor_usage_count[str(id(tensor))] += 1

    self._tensor_usage_count = tensor_usage_count
|
|
|
|
def _assert_weights_created(self):
|
|
# Override the implementation in Model.
|
|
# The Functional model should always have weight created already.
|
|
return
|
|
|
|
def _graph_network_add_loss(self, symbolic_loss):
|
|
new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
|
|
# Losses must be keyed on inputs no matter what in order to be supported in
|
|
# DistributionStrategy.
|
|
add_loss_layer = base_layer.AddLoss(
|
|
unconditional=False, dtype=symbolic_loss.dtype)
|
|
add_loss_layer(symbolic_loss)
|
|
new_nodes.extend(add_loss_layer.inbound_nodes)
|
|
new_layers.append(add_loss_layer)
|
|
self._insert_layers(new_layers, new_nodes)
|
|
|
|
def _graph_network_add_metric(self, value, aggregation, name):
|
|
new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
|
|
add_metric_layer = base_layer.AddMetric(
|
|
aggregation, name, dtype=value.dtype)
|
|
add_metric_layer(value)
|
|
new_nodes.extend(add_metric_layer.inbound_nodes)
|
|
new_layers.append(add_metric_layer)
|
|
self._insert_layers(new_layers, new_nodes)
|
|
|
|
@property
def _trackable_saved_model_saver(self):
  """Returns the SavedModel saver used to serialize this network."""
  saver = network_serialization.NetworkSavedModelSaver(self)
  return saver
|
|
|
|
|
|
def _make_node_key(layer_name, node_index):
|
|
return layer_name + '_ib-' + str(node_index)
|
|
|
|
|
|
def _map_graph_network(inputs, outputs):
  """Validates a network's topology and gather its layers and nodes.

  Arguments:
    inputs: List of input tensors.
    outputs: List of outputs tensors.

  Returns:
    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
    - nodes: set of string node keys (one per node reachable from `outputs`;
      see `_make_node_key`).
    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
    - layers: list of Layer instances.
    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

  Raises:
    ValueError: In case the network is not valid (e.g. disconnected graph),
      or if two layers share the same name.
  """
  # "depth" is number of layers between output Node and the Node.
  # Nodes are ordered from inputs -> outputs.
  nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
  # String keys uniquely identifying every node that belongs to this network.
  network_nodes = {
      _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
      for node in nodes_in_decreasing_depth
  }

  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}

  # Walk nodes from outputs back towards inputs, propagating depths.
  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer
    previous_depth = layers_depths.get(node.layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all nodes it is connected to + 1.
    for node_dep in node.parent_nodes:
      previous_depth = nodes_depths.get(node_dep, 0)
      nodes_depths[node_dep] = max(depth + 1, previous_depth)

  # Handle inputs that are not connected to outputs.
  # We do not error out here because the inputs may be used to compute losses
  # and metrics.
  for input_t in inputs:
    input_layer = input_t._keras_history[0]
    if input_layer not in layers_depths:
      layers_depths[input_layer] = 0
      # -1 sorts disconnected inputs before every traversed layer below.
      layer_indices[input_layer] = -1
      nodes_depths[input_layer._inbound_nodes[0]] = 0
      network_nodes.add(_make_node_key(input_layer.name, 0))

  # Build a dict {depth: list of nodes with this depth}
  nodes_by_depth = collections.defaultdict(list)
  for node, depth in nodes_depths.items():
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}
  layers_by_depth = collections.defaultdict(list)
  for layer, depth in layers_depths.items():
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers ordered by depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = set()
  for x in inputs:
    computable_tensors.add(id(x))

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.layer
      if layer and not node.is_input:
        for x in nest.flatten(node.keras_inputs):
          if id(x) not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        # All inputs were computable, so this node's outputs now are too.
        for x in nest.flatten(node.outputs):
          computable_tensors.add(id(x))
        layers_with_complete_input.append(layer.name)

  # Ensure name unicity, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth
|
|
|
|
|
|
def _build_map(outputs):
  """Topologically sorts nodes, from original inputs towards `outputs`.

  A depth-first search walks the `_keras_history` connectivity metadata of
  every tensor in `outputs` and records the nodes it finishes in
  post-order, which yields a topological ordering.

  Args:
    outputs: the output tensors whose _keras_history metadata should be
      walked. This may be an arbitrary nested structure.

  Returns:
    A tuple like (ordered_nodes, layer_to_first_traversal_index)
    ordered_nodes: list of nodes appearing in the keras history,
      topologically sorted from original inputs to the `outputs`.
      (If outputs have different sets of ancestors, the inputs to one
      output may appear after a different output.)
    layer_to_first_traversal_index: dict mapping layer to the traversal
      index in the DFS where it is first seen; shared layers keep only the
      index of their first visit.
  """
  finished = set()
  in_progress = set()
  ordered_nodes = []  # Nodes sorted from inputs towards outputs.
  traversal_index_by_layer = {}  # Layer -> first DFS visit index.
  for output_tensor in nest.flatten(outputs):
    _build_map_helper(output_tensor, finished, in_progress, ordered_nodes,
                      traversal_index_by_layer)
  return ordered_nodes, traversal_index_by_layer
|
|
|
|
|
|
def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices):
  """Recursive DFS helper for `_build_map`."""
  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

  # Shared subgraphs may lead to the same node twice; do the work only once.
  if node in finished_nodes:
    return

  # A node already on the current DFS stack means the graph has a cycle.
  if node in nodes_in_progress:
    raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name +
                     '" is part of a cycle.')

  # Record the traversal index of the first visit to each layer; used later
  # to sort layers deterministically.
  layer_indices.setdefault(layer, len(layer_indices))

  # Recurse into every tensor feeding this node before finishing it.
  nodes_in_progress.add(node)
  if not node.is_input:
    for parent_tensor in node.keras_inputs:
      _build_map_helper(parent_tensor, finished_nodes, nodes_in_progress,
                        nodes_in_decreasing_depth, layer_indices)

  finished_nodes.add(node)
  nodes_in_progress.remove(node)
  # Post-order append: ancestors always precede this node in the list.
  nodes_in_decreasing_depth.append(node)
|
|
|
|
|
|
def _map_subgraph_network(inputs, outputs):
  """Returns the nodes and layers in the topology from `inputs` to `outputs`.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple of List[Node] and List[Layer].
  """
  if not keras_tensor.keras_tensors_enabled():
    base_layer_utils.create_keras_history(outputs)
  # Keep only nodes and layers in the topology between inputs and outputs.
  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
  all_nodes = nest.flatten(list(nodes_by_depth.values()))
  return all_nodes, layers
|
|
|
|
|
|
def _should_skip_first_node(layer):
  """Returns True if the first layer node should not be saved or loaded."""
  # Networks constructed with an Input layer/shape start with a pre-existing
  # node linking their input to output; that node is excluded from the
  # network config.
  if not isinstance(layer, Functional):
    return False
  # Filter out Sequential models without an input shape.
  return isinstance(layer._layers[0], input_layer_module.InputLayer)
|
|
|
|
|
|
def _deserialize_keras_tensors(kwargs, layer_map):
  """Deserializes Keras Tensors passed to `call`.

  Arguments:
    kwargs: Nested structure possibly containing serialized Keras Tensor
      references (as `ListWrapper`s of [layer name, node index, tensor
      index]).
    layer_map: Dict mapping layer names to already-created Layer instances.

  Returns:
    `kwargs` with every serialized tensor reference replaced by the
    corresponding symbolic tensor.
  """

  def _deserialize_keras_tensor(t):
    """Deserializes a single Keras Tensor passed to `call`."""
    if not isinstance(t, tf_utils.ListWrapper):
      # Not a serialized tensor reference; pass through unchanged.
      return t
    data = t.as_list()
    layer_name, node_index, tensor_index = data[0], data[1], data[2]
    layer = layer_map[layer_name]
    node = layer._inbound_nodes[node_index]
    return nest.flatten(node.outputs)[tensor_index]

  wrapped_kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
  return nest.map_structure(_deserialize_keras_tensor, wrapped_kwargs)
|
|
|
|
|
|
def connect_ancillary_layers(model, created_layers):
  """Adds layers that are not connected to the outputs to the model.

  Arguments:
    model: Model to extend.
    created_layers: Dict mapping layer names to Layer instances; layers not
      already in `model.layers` (such as those added in `add_loss`) are
      inserted.

  Returns:
    The (possibly extended) `model`.
  """
  ancillary_layers = []
  for candidate in created_layers.values():
    if candidate not in model.layers:
      ancillary_layers.append(candidate)
  if not ancillary_layers:
    return model

  # Gather every node of every created layer, dropping the pre-existing
  # first node of layers that skip it during serialization.
  node_lists = []
  for layer in created_layers.values():
    nodes = layer.inbound_nodes
    if _should_skip_first_node(layer):
      nodes = nodes[1:]
    node_lists.append(nodes)
  model._insert_layers(ancillary_layers, nest.flatten(node_lists))
  return model
|
|
|
|
|
|
def reconstruct_from_config(config, custom_objects=None, created_layers=None):
  """Reconstructs graph from config object.

  Args:
    config: Dictionary returned from Network.get_config()
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.
    created_layers: Optional dictionary mapping names to Layer objects. Any
      layer not in this dictionary will be be created and added to the dict.
      This function will add new nodes to all layers (excluding InputLayers),
      instead of re-using pre-existing nodes in the layers.

  Returns:
    Tuple of (input tensors, output tensors, dictionary of created layers)
  """
  # Layer instances created during the graph reconstruction process.
  created_layers = created_layers or collections.OrderedDict()

  # Maps input data (tuple of inbound layer name, node index) from the config
  # to node indices in the newly generated model. The node indices may be
  # different if the layers have already been called previously.
  node_index_map = {}
  node_count_by_layer = {}

  # Dictionary mapping layer instances to
  # node data that specifies a layer call.
  # It acts as a queue that maintains any unprocessed
  # layer call until it becomes possible to process it
  # (i.e. until the input tensors to the call all exist).
  unprocessed_nodes = {}

  def add_unprocessed_node(layer, node_data):
    """Queues `node_data` for `layer` until its input tensors exist."""
    if layer not in unprocessed_nodes:
      unprocessed_nodes[layer] = [node_data]
    else:
      unprocessed_nodes[layer].append(node_data)

  def get_node_index(layer, config_node_index):
    """Returns node index in layer (might differ from config_node_index)."""
    if isinstance(layer, input_layer_module.InputLayer):
      return 0
    return node_index_map.get((layer.name, config_node_index), None)

  def process_node(layer, node_data):
    """Deserialize a node.

    Arguments:
      layer: layer instance.
      node_data: Nested structure of `ListWrapper`.

    Raises:
      ValueError: In case of improperly formatted `node_data`.
    """
    input_tensors = []
    for input_data in nest.flatten(node_data):
      input_data = input_data.as_list()
      inbound_layer_name = input_data[0]
      inbound_node_index = input_data[1]
      inbound_tensor_index = input_data[2]
      # Entries are [name, node_index, tensor_index] with optional kwargs.
      if len(input_data) == 3:
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        try:
          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
        except IndexError:
          # Happens if keras tensors in kwargs are still unprocessed
          add_unprocessed_node(layer, node_data)
          return
      else:
        raise ValueError('Improperly formatted model config.')

      if inbound_layer_name != node_module._CONSTANT_VALUE:
        inbound_layer = created_layers[inbound_layer_name]
        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

        # The inbound node has not been created yet; re-enqueue and retry
        # on a later pass of the fixed-point loop below.
        if inbound_node_index is None:
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(
            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
      else:
        # We received a constant w/ no Keras history attached
        input_tensors.append(inbound_tensor_index)
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.
    if input_tensors is not None:
      input_tensors = base_layer_utils.unnest_if_single_tensor(input_tensors)
      output_tensors = layer(input_tensors, **kwargs)

      # Update node index map.
      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
      node_count_by_layer[layer] += 1

  def process_layer(layer_data):
    """Deserializes a layer, then call it on appropriate inputs.

    Arguments:
      layer_data: layer config dict.

    Raises:
      ValueError: In case of improperly formatted `layer_data` dict.
    """
    layer_name = layer_data['name']

    if layer_name in created_layers:
      layer = created_layers[layer_name]
    else:
      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

    # Layers whose first (pre-existing) node is skipped start counting at 1.
    node_count_by_layer[layer] = int(_should_skip_first_node(layer))

    # Gather layer inputs and convert to `ListWrapper` objects.
    inbound_nodes_data = layer_data['inbound_nodes']
    inbound_nodes_data = tf_utils.convert_inner_node_data(
        inbound_nodes_data, wrap=True)
    for node_data in inbound_nodes_data:
      # We don't process nodes (i.e. make layer calls)
      # on the fly because the inbound node may not yet exist,
      # in case of layer shared at different topological depths
      # (e.g. a model such as A(B(A(B(x)))))
      add_unprocessed_node(layer, node_data)

  # First, we create all layers and enqueue nodes to be processed
  for layer_data in config['layers']:
    process_layer(layer_data)
  # Then we process nodes in order of layer depth.
  # Nodes that cannot yet be processed (if the inbound node
  # does not yet exist) are re-enqueued, and the process
  # is repeated until all nodes are processed.
  while unprocessed_nodes:
    for layer_data in config['layers']:
      layer = created_layers[layer_data['name']]
      if layer in unprocessed_nodes:
        for node_data in unprocessed_nodes.pop(layer):
          process_node(layer, node_data)

  input_tensors = []
  output_tensors = []

  # Resolve the model's input tensors from the created layers' nodes.
  input_layers = tf_utils.convert_inner_node_data(
      config['input_layers'], wrap=True)
  for layer_data in nest.flatten(input_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  # Resolve the model's output tensors the same way.
  output_layers = tf_utils.convert_inner_node_data(
      config['output_layers'], wrap=True)
  for layer_data in nest.flatten(output_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  # Restore the original nesting of inputs/outputs from the config.
  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
  return input_tensors, output_tensors, created_layers
|
|
|
|
|
|
def get_network_config(network, serialize_layer_fn=None):
  """Builds the config, which consists of the node graph and serialized layers.

  Args:
    network: A Network object.
    serialize_layer_fn: Function used to serialize layers. Defaults to
      `generic_utils.serialize_keras_object`.

  Returns:
    Config dictionary with keys 'name', 'layers', 'input_layers' and
    'output_layers'.
  """
  serialize_layer_fn = (
      serialize_layer_fn or generic_utils.serialize_keras_object)
  config = {
      'name': network.name,
  }
  # Maps original node keys to the node index used in the serialized config,
  # counting only nodes that belong to this network (skipped nodes are
  # removed, so indices are re-numbered).
  node_conversion_map = {}
  for layer in network.layers:
    kept_nodes = 1 if _should_skip_first_node(layer) else 0
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        node_conversion_map[node_key] = kept_nodes
        kept_nodes += 1
  layer_configs = []
  for layer in network.layers:  # From the earliest layers on.
    filtered_inbound_nodes = []
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes and not node.is_input:
        # The node is relevant to the model:
        # add to filtered_inbound_nodes.
        node_data = node.serialize(_make_node_key, node_conversion_map)
        filtered_inbound_nodes.append(node_data)

    layer_config = serialize_layer_fn(layer)
    layer_config['name'] = layer.name
    layer_config['inbound_nodes'] = filtered_inbound_nodes
    layer_configs.append(layer_config)
  config['layers'] = layer_configs

  # Gather info about inputs and outputs.
  model_inputs = []
  for i in range(len(network._input_layers)):
    layer, node_index, tensor_index = network._input_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_inputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
  # Preserve external Keras compat for Models with single input.
  if not nest.is_sequence(model_inputs):
    model_inputs = [model_inputs]
  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
  config['input_layers'] = model_inputs

  model_outputs = []
  for i in range(len(network._output_layers)):
    layer, node_index, tensor_index = network._output_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_outputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
  # Preserve external Keras compat for Models with single output.
  if not nest.is_sequence(model_outputs):
    model_outputs = [model_outputs]
  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
  config['output_layers'] = model_outputs
  return config
|