293 lines
11 KiB
Python
293 lines
11 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
# pylint: disable=protected-access
|
|
"""Contains the `Node` class."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import copy
|
|
import json
|
|
import numpy as np
|
|
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.keras import backend
|
|
from tensorflow.python.keras.engine import base_layer_utils
|
|
from tensorflow.python.keras.engine import keras_tensor
|
|
from tensorflow.python.keras.utils import tf_utils
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.util import nest
|
|
from tensorflow.python.util import serialization
|
|
|
|
# Tag written as the first element of a serialized first-argument entry in
# `Node.serialize` when that entry is not a Keras tensor but a constant value.
_CONSTANT_VALUE = "_CONSTANT_VALUE"
|
|
|
|
|
|
class Node(object):
  """A `Node` describes the connectivity between two layers.

  Each time a layer is connected to some new input,
  a node is added to `layer._inbound_nodes`.
  Each time the output of a layer is used by another layer,
  a node is added to `layer._outbound_nodes`.

  Arguments:
      layer: The Layer for the Layer.__call__ this node represents.
      call_args: The positional arguments the Layer was called with.
      call_kwargs: The keyword arguments the Layer was called with.
      outputs: The outputs of the Layer.__call__
  """

  def __init__(self,
               layer,
               call_args=None,
               call_kwargs=None,
               outputs=None):
    call_args = [] if call_args is None else call_args
    call_kwargs = {} if call_kwargs is None else call_kwargs
    outputs = [] if outputs is None else outputs

    self.layer = layer
    # A node built without any call arguments is treated as an input node;
    # the backwards-compatibility properties below special-case it.
    self.is_input = not call_args and not call_kwargs

    # These arguments are user-provided. Copy the structures here so that
    # future user modifications do not affect the node's metadata.
    # We copy using map_structure rather than python's shallow or deep copy,
    # because the args can be data structures (so shallow copy is
    # insufficient), but individual values might not support copy.copy
    # or be too expensive to deep copy.
    call_args = nest.map_structure(lambda t: t, call_args)
    call_kwargs = nest.map_structure(lambda t: t, call_kwargs)
    self.outputs = nest.map_structure(lambda t: t, outputs)
    self.call_args = call_args
    self.call_kwargs = call_kwargs

    # Cached for performance.
    self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs))
    # Used to avoid expensive `nest` operations in the most common case.
    self._single_positional_tensor_passed = (not self.call_kwargs and len(
        self.call_args) == 1 and tensor_util.is_tensor(self.call_args[0]))

    if not keras_tensor.keras_tensors_enabled():
      # Create TensorFlowOpLayers if needed.
      for obj in self._flat_arguments:
        if (isinstance(obj, ops.Tensor) and
            base_layer_utils.needs_keras_history(
                obj, ignore_call_context=True)):
          base_layer_utils.create_keras_history(obj)

    # Record every flattened argument that carries Keras history, together
    # with its stringified object id and its position in
    # `self._flat_arguments`. `map_arguments` uses these pairs to substitute
    # computed values back into place.
    self._keras_inputs = []
    self._keras_inputs_ids_and_indices = []
    for i, ele in enumerate(self._flat_arguments):
      if is_keras_tensor(ele):
        self._keras_inputs.append(ele)
        self._keras_inputs_ids_and_indices.append((str(id(ele)), i))

    # Wire up Node to Layers.
    self.layer._inbound_nodes.append(self)
    for kt in self.keras_inputs:
      inbound_layer = kt._keras_history.layer
      if inbound_layer is not None:  # `None` for `Input` tensors.
        inbound_layer._outbound_nodes.append(self)

    # Set metadata on outputs. This node was just appended above, so its
    # index is the last position in the layer's inbound nodes.
    node_index = len(self.layer._inbound_nodes) - 1
    for i, tensor in enumerate(nest.flatten(outputs)):
      tensor._keras_history = KerasHistory(
          layer=layer, node_index=node_index, tensor_index=i)

    # Cached for performance.
    self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
    self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]

  @property
  def keras_inputs(self):
    """Tensors input to this node that can be traced back to a `keras.Input`."""
    return self._keras_inputs

  @property
  def parent_nodes(self):
    """Returns all the `Node`s whose output this node immediately depends on."""
    node_deps = []
    for kt in self.keras_inputs:
      layer = kt._keras_history.layer
      node_index = kt._keras_history.node_index
      if layer is not None:  # `None` for `Input` tensors.
        node_deps.append(layer._inbound_nodes[node_index])
    return node_deps

  def iterate_inbound(self):
    """Yields tuples representing the data inbound from other nodes.

    Yields:
      tuples like: (inbound_layer, node_index, tensor_index, tensor).
    """
    for kt in self.keras_inputs:
      keras_history = kt._keras_history
      layer = keras_history.layer
      node_index = keras_history.node_index
      tensor_index = keras_history.tensor_index
      yield layer, node_index, tensor_index, kt

  def map_arguments(self, tensor_dict):
    """Maps Keras Tensors to computed Tensors using `tensor_dict`.

    Args:
      tensor_dict: Mapping from the stringified id of each Keras input to a
        list of computed values. One value is popped from the end of the
        list for each use, so the lists are consumed in place.

    Returns:
      A `(args, kwargs)` pair with the same structure as the original call,
      where each Keras input has been replaced by its computed value.
    """
    if (self._single_positional_tensor_passed and
        self._keras_inputs_ids_and_indices):
      # Performance optimization for the most common case: a single Keras
      # tensor passed positionally. The emptiness check keeps a lone
      # positional tensor without Keras history from raising IndexError;
      # that case falls through to the general path below.
      kt_id, _ = self._keras_inputs_ids_and_indices[0]
      return (tensor_dict[kt_id].pop(),), {}

    flat_arguments = copy.copy(self._flat_arguments)
    for kt_id, kt_index in self._keras_inputs_ids_and_indices:
      flat_arguments[kt_index] = tensor_dict[kt_id].pop()

    args, kwargs = nest.pack_sequence_as((self.call_args, self.call_kwargs),
                                         flat_arguments)
    return args, kwargs

  def serialize(self, make_node_key, node_conversion_map):
    """Serializes `Node` for Functional API's `get_config`.

    Args:
      make_node_key: Callable invoked as `make_node_key(layer_name,
        node_index)` to build the lookup key into `node_conversion_map`.
      node_conversion_map: Mapping from node keys to the node index written
        into the config. Keys that are absent map to 0.

    Returns:
      The per-tensor inbound data for the first call argument, passed
      through `tf_utils.convert_inner_node_data`.
    """
    # Serialization still special-cases first argument.
    args, kwargs = self.call_args, self.call_kwargs
    inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs)

    # Treat everything other than first argument as a kwarg.
    arguments = dict(zip(self.layer._call_fn_args[1:], args))
    arguments.update(kwargs)
    kwargs = arguments

    kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
    try:
      json.dumps(kwargs, default=serialization.get_json_type)
    except TypeError:
      kwarg_types = nest.map_structure(type, kwargs)
      logging.warning('Layer ' + self.layer.name +
                      ' was passed non-JSON-serializable arguments. ' +
                      'Arguments had types: ' +
                      str(kwarg_types) + '. They will not be included '
                      'in the serialized model (and thus will be missing '
                      'at deserialization time).')
      kwargs = {}

    # `kwargs` is added to each Tensor in the first arg. This should be
    # changed in a future version of the serialization format.
    def serialize_first_arg_tensor(t):
      if is_keras_tensor(t):
        kh = t._keras_history
        node_index = kh.node_index
        node_key = make_node_key(kh.layer.name, node_index)
        new_node_index = node_conversion_map.get(node_key, 0)
        data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
      else:
        # If an element in the first call argument did not originate as a
        # keras tensor and is a constant value, we save it using the format
        # ['_CONSTANT_VALUE', -1, serialized_tensor_or_python_constant]
        # (potentially including serialized kwargs in an optional 4th
        # element).
        data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
      return tf_utils.ListWrapper(data)

    data = nest.map_structure(serialize_first_arg_tensor, inputs)
    if not nest.is_sequence(data):
      data = [data]
    data = tf_utils.convert_inner_node_data(data)
    return data

  #############################################################
  # Properties for Backwards compatibility.
  # These only check the first input argument
  # As nodes are internal, they may be removed in the future.
  #############################################################

  @property
  def input_tensors(self):
    if self.is_input:
      return [self.outputs]  # Used in `Layer.input`.
    return self.call_args[0]

  @property
  def output_tensors(self):
    if self.is_input:
      return [self.outputs]  # Used in `Layer.input`.
    return self.outputs

  @property
  def input_shapes(self):
    inputs = self.input_tensors
    input_shapes = nest.map_structure(backend.int_shape, inputs)
    # Unwrap a one-element list/tuple of inputs to that element's shape.
    # The length test must be applied to the input *structure*, not to the
    # mapped result: when `inputs` is a single tensor, `input_shapes` is
    # already that tensor's shape, and indexing it would return its first
    # dimension instead of the whole shape (e.g. for a rank-1 tensor).
    if (not self.is_input and isinstance(inputs, (list, tuple)) and
        len(inputs) == 1):
      return input_shapes[0]
    return input_shapes

  @property
  def output_shapes(self):
    return nest.map_structure(backend.int_shape, self.output_tensors)

  @property
  def outbound_layer(self):
    return self.layer

  @property
  def inbound_layers(self):
    if self.is_input:
      return []
    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
                                        self.call_args[0])
    return inbound_layers
|
|
|
|
|
|
class KerasHistory(
    collections.namedtuple('KerasHistory', 'layer node_index tensor_index')):
  """Records which Layer call produced a Tensor in a Keras Graph Network.

  While a Keras Graph Network is being built, every Tensor that a Layer
  outputs (beginning with an `InputLayer`) is tagged with one of these
  records. The `keras.engine.Network` class later retraces these records to
  reconstruct the Keras Graph Network.

  Attributes:
    layer: The Layer that produced the Tensor.
    node_index: Which call of `layer` produced the Tensor. A Layer may be
      called several times to share weights, and each call creates a new
      node.
    tensor_index: Position of the Tensor among the Layer's outputs. It is
      zero for a Layer with a single output. Nested output structures are
      indexed deterministically in `nest.flatten` order.
  """
  # Subclassing a `namedtuple` would otherwise give every instance a
  # `__dict__`; an empty `__slots__` keeps instances as compact as the
  # underlying tuple.
  __slots__ = ()
|
|
|
|
|
|
def is_keras_tensor(obj):
  """Reports whether `obj` carries Keras history metadata.

  Args:
    obj: Any object.

  Returns:
    True if `obj` has a `_keras_history` attribute, False otherwise.
  """
  try:
    getattr(obj, '_keras_history')
  except AttributeError:
    return False
  return True
|
|
|
|
|
|
def _serialize_keras_tensor(t):
  """Converts a single argument passed to `call` into a serializable value.

  Objects carrying `_keras_history` become
  `[layer_name, node_index, tensor_index]`. NumPy arrays are converted with
  `tolist()`. `ops.Tensor`s are evaluated with `backend.get_value` and then
  converted with `tolist()`. Any other value is returned unchanged.

  Args:
    t: The argument to serialize.

  Returns:
    The serialized representation described above.
  """
  if hasattr(t, '_keras_history'):
    history = t._keras_history
    result = [history.layer.name, history.node_index, history.tensor_index]
  elif isinstance(t, np.ndarray):
    result = t.tolist()
  elif isinstance(t, ops.Tensor):
    result = backend.get_value(t).tolist()
  else:
    result = t
  return result
|