Move more layer related utils to frozen_keras.
PiperOrigin-RevId: 299879819 Change-Id: Iebf2ea84be845bb09b926bdb3391375965e150aa
This commit is contained in:
parent
9a7288ccf5
commit
105c91f2c4
@ -12,6 +12,8 @@ py_library(
|
|||||||
srcs = ["legacy_base_layer.py"],
|
srcs = ["legacy_base_layer.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":base_layer_utils",
|
":base_layer_utils",
|
||||||
|
":input_spec",
|
||||||
|
":node",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:auto_control_deps",
|
"//tensorflow/python:auto_control_deps",
|
||||||
@ -38,7 +40,6 @@ py_library(
|
|||||||
"//tensorflow/python/keras:backend",
|
"//tensorflow/python/keras:backend",
|
||||||
"//tensorflow/python/keras:metrics",
|
"//tensorflow/python/keras:metrics",
|
||||||
"//tensorflow/python/keras/engine",
|
"//tensorflow/python/keras/engine",
|
||||||
"//tensorflow/python/keras/engine:input_spec",
|
|
||||||
"//tensorflow/python/keras/saving",
|
"//tensorflow/python/keras/saving",
|
||||||
"//tensorflow/python/keras/utils:generic_utils",
|
"//tensorflow/python/keras/utils:generic_utils",
|
||||||
"//tensorflow/python/keras/utils:layer_utils",
|
"//tensorflow/python/keras/utils:layer_utils",
|
||||||
@ -78,6 +79,31 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "input_spec",
|
||||||
|
srcs = ["input_spec.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:lib",
|
||||||
|
"//tensorflow/python:tensor_shape",
|
||||||
|
"//tensorflow/python:tensor_spec",
|
||||||
|
"//tensorflow/python/keras:backend",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "node",
|
||||||
|
srcs = ["node.py"],
|
||||||
|
deps = [
|
||||||
|
":base_layer_utils",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
"//tensorflow/python/keras:backend",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "legacy_base_layer_test",
|
name = "legacy_base_layer_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
@ -110,3 +136,18 @@ tf_py_test(
|
|||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "input_spec_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["input_spec_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":input_spec",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
233
tensorflow/python/frozen_keras/engine/input_spec.py
Normal file
233
tensorflow/python/frozen_keras/engine/input_spec.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
# pylint: disable=g-classes-have-attributes
|
||||||
|
"""Contains the InputSpec class."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from six.moves import zip # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
from tensorflow.python.keras import backend
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
|
class InputSpec(object):
|
||||||
|
"""Specifies the rank, dtype and shape of every input to a layer.
|
||||||
|
|
||||||
|
Layers can expose (if appropriate) an `input_spec` attribute:
|
||||||
|
an instance of `InputSpec`, or a nested structure of `InputSpec` instances
|
||||||
|
(one per input tensor). These objects enable the layer to run input
|
||||||
|
compatibility checks for input structure, input rank, input shape, and
|
||||||
|
input dtype.
|
||||||
|
|
||||||
|
A None entry in a shape is compatible with any dimension,
|
||||||
|
a None shape is compatible with any shape.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
dtype: Expected DataType of the input.
|
||||||
|
shape: Shape tuple, expected shape of the input
|
||||||
|
(may include None for unchecked axes).
|
||||||
|
ndim: Integer, expected rank of the input.
|
||||||
|
max_ndim: Integer, maximum rank of the input.
|
||||||
|
min_ndim: Integer, minimum rank of the input.
|
||||||
|
axes: Dictionary mapping integer axes to
|
||||||
|
a specific dimension value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dtype=None,
|
||||||
|
shape=None,
|
||||||
|
ndim=None,
|
||||||
|
max_ndim=None,
|
||||||
|
min_ndim=None,
|
||||||
|
axes=None):
|
||||||
|
self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
|
||||||
|
if shape is not None:
|
||||||
|
self.ndim = len(shape)
|
||||||
|
self.shape = shape
|
||||||
|
else:
|
||||||
|
self.ndim = ndim
|
||||||
|
self.shape = None
|
||||||
|
self.max_ndim = max_ndim
|
||||||
|
self.min_ndim = min_ndim
|
||||||
|
try:
|
||||||
|
axes = axes or {}
|
||||||
|
self.axes = {int(k): axes[k] for k in axes}
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise TypeError('The keys in axes must be integers.')
|
||||||
|
|
||||||
|
if self.axes and (self.ndim is not None or self.max_ndim is not None):
|
||||||
|
max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
|
||||||
|
max_axis = max(self.axes)
|
||||||
|
if max_axis > max_dim:
|
||||||
|
raise ValueError('Axis {} is greater than the maximum allowed value: {}'
|
||||||
|
.format(max_axis, max_dim))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
|
||||||
|
('shape=' + str(self.shape)) if self.shape else '',
|
||||||
|
('ndim=' + str(self.ndim)) if self.ndim else '',
|
||||||
|
('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
|
||||||
|
('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
|
||||||
|
('axes=' + str(self.axes)) if self.axes else '']
|
||||||
|
return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {
|
||||||
|
'dtype': self.dtype,
|
||||||
|
'shape': self.shape,
|
||||||
|
'ndim': self.ndim,
|
||||||
|
'max_ndim': self.max_ndim,
|
||||||
|
'min_ndim': self.min_ndim,
|
||||||
|
'axes': self.axes}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
return cls(**config)
|
||||||
|
|
||||||
|
|
||||||
|
def to_tensor_shape(spec):
|
||||||
|
"""Returns a tf.TensorShape object that matches the shape specifications.
|
||||||
|
|
||||||
|
If the InputSpec's shape or ndim is defined, this method will return a fully
|
||||||
|
or partially-known shape. Otherwise, the returned TensorShape is None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec: an InputSpec object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a tf.TensorShape object
|
||||||
|
"""
|
||||||
|
if spec.ndim is None and spec.shape is None:
|
||||||
|
return tensor_shape.TensorShape(None)
|
||||||
|
elif spec.shape is not None:
|
||||||
|
return tensor_shape.TensorShape(spec.shape)
|
||||||
|
else:
|
||||||
|
shape = [None] * spec.ndim
|
||||||
|
for a in spec.axes:
|
||||||
|
shape[a] = spec.axes[a] # Assume that axes is defined
|
||||||
|
return tensor_shape.TensorShape(shape)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_input_compatibility(input_spec, inputs, layer_name):
|
||||||
|
"""Checks compatibility between the layer and provided inputs.
|
||||||
|
|
||||||
|
This checks that the tensor(s) `inputs` verify the input assumptions
|
||||||
|
of a layer (if any). If not, a clear and actional exception gets raised.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
input_spec: An InputSpec instance, list of InputSpec instances, a nested
|
||||||
|
structure of InputSpec instances, or None.
|
||||||
|
inputs: Input tensor, list of input tensors, or a nested structure of
|
||||||
|
input tensors.
|
||||||
|
layer_name: String, name of the layer (for error message formatting).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: in case of mismatch between
|
||||||
|
the provided inputs and the expectations of the layer.
|
||||||
|
"""
|
||||||
|
if not input_spec:
|
||||||
|
return
|
||||||
|
|
||||||
|
inputs = nest.flatten(inputs)
|
||||||
|
input_spec = nest.flatten(input_spec)
|
||||||
|
if len(inputs) != len(input_spec):
|
||||||
|
raise ValueError('Layer ' + layer_name + ' expects ' +
|
||||||
|
str(len(input_spec)) + ' inputs, '
|
||||||
|
'but it received ' + str(len(inputs)) +
|
||||||
|
' input tensors. Inputs received: ' + str(inputs))
|
||||||
|
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
|
||||||
|
if spec is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (spec.ndim is not None or
|
||||||
|
spec.min_ndim is not None or
|
||||||
|
spec.max_ndim is not None):
|
||||||
|
if x.shape.ndims is None:
|
||||||
|
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||||
|
layer_name + ' is incompatible with the layer: '
|
||||||
|
'its rank is undefined, but the layer requires a '
|
||||||
|
'defined rank.')
|
||||||
|
|
||||||
|
# Check ndim.
|
||||||
|
if spec.ndim is not None:
|
||||||
|
ndim = x.shape.ndims
|
||||||
|
if ndim != spec.ndim:
|
||||||
|
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||||
|
layer_name + ' is incompatible with the layer: '
|
||||||
|
'expected ndim=' + str(spec.ndim) + ', found ndim=' +
|
||||||
|
str(ndim) + '. Full shape received: ' +
|
||||||
|
str(x.shape.as_list()))
|
||||||
|
if spec.max_ndim is not None:
|
||||||
|
ndim = x.shape.ndims
|
||||||
|
if ndim is not None and ndim > spec.max_ndim:
|
||||||
|
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||||
|
layer_name + ' is incompatible with the layer: '
|
||||||
|
'expected max_ndim=' + str(spec.max_ndim) +
|
||||||
|
', found ndim=' + str(ndim))
|
||||||
|
if spec.min_ndim is not None:
|
||||||
|
ndim = x.shape.ndims
|
||||||
|
if ndim is not None and ndim < spec.min_ndim:
|
||||||
|
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||||
|
layer_name + ' is incompatible with the layer: '
|
||||||
|
': expected min_ndim=' + str(spec.min_ndim) +
|
||||||
|
', found ndim=' + str(ndim) +
|
||||||
|
'. Full shape received: ' +
|
||||||
|
str(x.shape.as_list()))
|
||||||
|
# Check dtype.
|
||||||
|
if spec.dtype is not None:
|
||||||
|
if x.dtype != spec.dtype:
|
||||||
|
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||||
|
layer_name + ' is incompatible with the layer: '
|
||||||
|
'expected dtype=' + str(spec.dtype) +
|
||||||
|
', found dtype=' + str(x.dtype))
|
||||||
|
# Check specific shape axes.
|
||||||
|
if spec.axes:
|
||||||
|
shape = x.shape.as_list()
|
||||||
|
if shape is not None:
|
||||||
|
for axis, value in spec.axes.items():
|
||||||
|
if hasattr(value, 'value'):
|
||||||
|
value = value.value
|
||||||
|
if value is not None and shape[int(axis)] not in {value, None}:
|
||||||
|
raise ValueError(
|
||||||
|
'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
|
||||||
|
' incompatible with the layer: expected axis ' + str(axis) +
|
||||||
|
' of input shape to have value ' + str(value) +
|
||||||
|
' but received input with shape ' + str(shape))
|
||||||
|
# Check shape.
|
||||||
|
if spec.shape is not None:
|
||||||
|
shape = x.shape.as_list()
|
||||||
|
if shape is not None:
|
||||||
|
for spec_dim, dim in zip(spec.shape, shape):
|
||||||
|
if spec_dim is not None and dim is not None:
|
||||||
|
if spec_dim != dim:
|
||||||
|
raise ValueError('Input ' + str(input_index) +
|
||||||
|
' is incompatible with layer ' + layer_name +
|
||||||
|
': expected shape=' + str(spec.shape) +
|
||||||
|
', found shape=' + str(shape))
|
||||||
|
|
||||||
|
|
||||||
|
def to_tensor_spec(input_spec, default_dtype=None):
|
||||||
|
"""Converts a Keras InputSpec object to a TensorSpec."""
|
||||||
|
default_dtype = default_dtype or backend.floatx()
|
||||||
|
if isinstance(input_spec, InputSpec):
|
||||||
|
dtype = input_spec.dtype or default_dtype
|
||||||
|
return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
|
||||||
|
return tensor_spec.TensorSpec(None, default_dtype)
|
66
tensorflow/python/frozen_keras/engine/input_spec_test.py
Normal file
66
tensorflow/python/frozen_keras/engine/input_spec_test.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""InputSpec tests."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.frozen_keras.engine import input_spec
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class InputSpecTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_axes_initialization(self):
|
||||||
|
input_spec.InputSpec(shape=[1, None, 2, 3], axes={3: 5, '2': 2})
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Axis 4 is greater than'):
|
||||||
|
input_spec.InputSpec(shape=[1, None, 2, 3], axes={4: 5})
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'keys in axes must be integers'):
|
||||||
|
input_spec.InputSpec(shape=[1, None, 2, 3], axes={'string': 5})
|
||||||
|
|
||||||
|
|
||||||
|
class InputSpecToTensorShapeTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_defined_shape(self):
|
||||||
|
spec = input_spec.InputSpec(shape=[1, None, 2, 3])
|
||||||
|
self.assertAllEqual(
|
||||||
|
[1, None, 2, 3], input_spec.to_tensor_shape(spec).as_list())
|
||||||
|
|
||||||
|
def test_defined_ndims(self):
|
||||||
|
spec = input_spec.InputSpec(ndim=5)
|
||||||
|
self.assertAllEqual(
|
||||||
|
[None] * 5, input_spec.to_tensor_shape(spec).as_list())
|
||||||
|
|
||||||
|
spec = input_spec.InputSpec(ndim=0)
|
||||||
|
self.assertAllEqual(
|
||||||
|
[], input_spec.to_tensor_shape(spec).as_list())
|
||||||
|
|
||||||
|
spec = input_spec.InputSpec(ndim=3, axes={1: 3, -1: 2})
|
||||||
|
self.assertAllEqual(
|
||||||
|
[None, 3, 2], input_spec.to_tensor_shape(spec).as_list())
|
||||||
|
|
||||||
|
def test_undefined_shapes(self):
|
||||||
|
spec = input_spec.InputSpec(max_ndim=5)
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
|
||||||
|
input_spec.to_tensor_shape(spec).as_list()
|
||||||
|
|
||||||
|
spec = input_spec.InputSpec(min_ndim=5, max_ndim=5)
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
|
||||||
|
input_spec.to_tensor_shape(spec).as_list()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -54,10 +54,10 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.frozen_keras import constraints
|
from tensorflow.python.frozen_keras import constraints
|
||||||
from tensorflow.python.frozen_keras import initializers
|
from tensorflow.python.frozen_keras import initializers
|
||||||
from tensorflow.python.frozen_keras import regularizers
|
from tensorflow.python.frozen_keras import regularizers
|
||||||
|
from tensorflow.python.frozen_keras.engine import base_layer_utils
|
||||||
|
from tensorflow.python.frozen_keras.engine import input_spec
|
||||||
|
from tensorflow.python.frozen_keras.engine import node as node_module
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras.engine import base_layer_utils
|
|
||||||
from tensorflow.python.keras.engine import input_spec
|
|
||||||
from tensorflow.python.keras.engine import node as node_module
|
|
||||||
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
|
189
tensorflow/python/frozen_keras/engine/node.py
Normal file
189
tensorflow/python/frozen_keras/engine/node.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
"""Contains the `Node` class."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.frozen_keras.engine import base_layer_utils
|
||||||
|
from tensorflow.python.keras import backend
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
|
class Node(object):
|
||||||
|
"""A `Node` describes the connectivity between two layers.
|
||||||
|
|
||||||
|
Each time a layer is connected to some new input,
|
||||||
|
a node is added to `layer._inbound_nodes`.
|
||||||
|
Each time the output of a layer is used by another layer,
|
||||||
|
a node is added to `layer._outbound_nodes`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
outbound_layer: the layer that takes
|
||||||
|
`input_tensors` and turns them into `output_tensors`
|
||||||
|
(the node gets created when the `call`
|
||||||
|
method of the layer was called).
|
||||||
|
inbound_layers: a list of layers, the same length as `input_tensors`,
|
||||||
|
the layers from where `input_tensors` originate.
|
||||||
|
node_indices: a list of integers, the same length as `inbound_layers`.
|
||||||
|
`node_indices[i]` is the origin node of `input_tensors[i]`
|
||||||
|
(necessary since each inbound layer might have several nodes,
|
||||||
|
e.g. if the layer is being shared with a different data stream).
|
||||||
|
tensor_indices: a list of integers,
|
||||||
|
the same length as `inbound_layers`.
|
||||||
|
`tensor_indices[i]` is the index of `input_tensors[i]` within the
|
||||||
|
output of the inbound layer
|
||||||
|
(necessary since each inbound layer might
|
||||||
|
have multiple tensor outputs, with each one being
|
||||||
|
independently manipulable).
|
||||||
|
input_tensors: list of input tensors.
|
||||||
|
output_tensors: list of output tensors.
|
||||||
|
arguments: dictionary of keyword arguments that were passed to the
|
||||||
|
`call` method of the layer at the call that created the node.
|
||||||
|
|
||||||
|
`node_indices` and `tensor_indices` are basically fine-grained coordinates
|
||||||
|
describing the origin of the `input_tensors`.
|
||||||
|
|
||||||
|
A node from layer A to layer B is added to:
|
||||||
|
- A._outbound_nodes
|
||||||
|
- B._inbound_nodes
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
outbound_layer,
|
||||||
|
inbound_layers,
|
||||||
|
node_indices,
|
||||||
|
tensor_indices,
|
||||||
|
input_tensors,
|
||||||
|
output_tensors,
|
||||||
|
arguments=None):
|
||||||
|
# Layer instance (NOT a sequence)
|
||||||
|
if isinstance(outbound_layer, (list, tuple, dict)):
|
||||||
|
raise ValueError('`outbound_layer` should be a layer instance, '
|
||||||
|
'not a list, tuple, or, dict.')
|
||||||
|
|
||||||
|
# this is the layer that takes a nested structure of input tensors
|
||||||
|
# and turns them into a nested structure of output tensors.
|
||||||
|
# the current node will be added to
|
||||||
|
# the inbound_nodes of outbound_layer.
|
||||||
|
self.outbound_layer = outbound_layer
|
||||||
|
|
||||||
|
# The following 3 properties describe where
|
||||||
|
# the input tensors come from: which layers,
|
||||||
|
# and for each layer, which node and which
|
||||||
|
# tensor output of each node.
|
||||||
|
|
||||||
|
# Nested structure of layer instances.
|
||||||
|
self.inbound_layers = inbound_layers
|
||||||
|
# Nested structure of integers, 1:1 mapping with inbound_layers.
|
||||||
|
self.node_indices = node_indices
|
||||||
|
# Nested of integers, 1:1 mapping with inbound_layers.
|
||||||
|
self.tensor_indices = tensor_indices
|
||||||
|
|
||||||
|
# Following 2 properties:
|
||||||
|
# tensor inputs and outputs of outbound_layer.
|
||||||
|
|
||||||
|
# Nested structure of tensors. 1:1 mapping with inbound_layers.
|
||||||
|
self.input_tensors = input_tensors
|
||||||
|
# Nested structure of tensors, created by outbound_layer.call().
|
||||||
|
self.output_tensors = output_tensors
|
||||||
|
|
||||||
|
# Following 2 properties: input and output shapes.
|
||||||
|
|
||||||
|
# Nested structure of shape tuples, shapes of input_tensors.
|
||||||
|
self.input_shapes = nest.map_structure(backend.int_shape, input_tensors)
|
||||||
|
# Nested structure of shape tuples, shapes of output_tensors.
|
||||||
|
self.output_shapes = nest.map_structure(backend.int_shape, output_tensors)
|
||||||
|
|
||||||
|
# Optional keyword arguments to layer's `call`.
|
||||||
|
self.arguments = arguments
|
||||||
|
|
||||||
|
# Create Keras History for any Keras Tensors in `arguments`.
|
||||||
|
tensor_arguments = [
|
||||||
|
t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
|
||||||
|
]
|
||||||
|
for tensor_argument in tensor_arguments:
|
||||||
|
if base_layer_utils.needs_keras_history(
|
||||||
|
tensor_argument, ignore_call_context=True):
|
||||||
|
base_layer_utils.create_keras_history(tensor_argument)
|
||||||
|
|
||||||
|
# Add nodes to all layers involved.
|
||||||
|
for layer in nest.flatten(inbound_layers):
|
||||||
|
if layer is not None:
|
||||||
|
# For compatibility with external Keras, we use the deprecated
|
||||||
|
# accessor here.
|
||||||
|
layer.outbound_nodes.append(self)
|
||||||
|
# For compatibility with external Keras, we use the deprecated
|
||||||
|
# accessor here.
|
||||||
|
outbound_layer.inbound_nodes.append(self)
|
||||||
|
|
||||||
|
def iterate_inbound(self, include_arguments=False):
|
||||||
|
"""Returns a list of tuples representing the inbound data.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
include_arguments: Whether to also iterate over any Keras Tensors
|
||||||
|
passed as args, kwargs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
|
||||||
|
"""
|
||||||
|
inputs_inbound = list(
|
||||||
|
zip(
|
||||||
|
nest.flatten(self.inbound_layers),
|
||||||
|
nest.flatten(self.node_indices),
|
||||||
|
nest.flatten(self.tensor_indices),
|
||||||
|
nest.flatten(self.input_tensors)))
|
||||||
|
|
||||||
|
if include_arguments:
|
||||||
|
keras_tensor_arguments = [
|
||||||
|
kt for kt in nest.flatten(self.arguments)
|
||||||
|
if hasattr(kt, '_keras_history')
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_inbound(keras_tensor):
|
||||||
|
kh = keras_tensor._keras_history
|
||||||
|
return kh.layer, kh.node_index, kh.tensor_index, keras_tensor
|
||||||
|
|
||||||
|
arguments_inbound = nest.map_structure(_get_inbound,
|
||||||
|
keras_tensor_arguments)
|
||||||
|
|
||||||
|
return inputs_inbound + arguments_inbound
|
||||||
|
else:
|
||||||
|
return inputs_inbound
|
||||||
|
|
||||||
|
def _get_all_node_dependencies(self):
|
||||||
|
"""Returns all of the nodes this node immediately depends on."""
|
||||||
|
node_deps = []
|
||||||
|
for layer, node_index, _, _ in self.iterate_inbound():
|
||||||
|
node_deps.append(layer._inbound_nodes[node_index])
|
||||||
|
|
||||||
|
for arg in nest.flatten(self.arguments):
|
||||||
|
if isinstance(arg, ops.Tensor) and hasattr(arg, '_keras_history'):
|
||||||
|
kh = arg._keras_history
|
||||||
|
node_deps.append(kh.layer._inbound_nodes[kh.node_index])
|
||||||
|
|
||||||
|
return node_deps
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
inbound_names = nest.map_structure(
|
||||||
|
lambda layer: layer.name if layer else None, self.inbound_layers)
|
||||||
|
return {
|
||||||
|
'outbound_layer': self.outbound_layer.name,
|
||||||
|
'inbound_layers': inbound_names,
|
||||||
|
'node_indices': self.node_indices,
|
||||||
|
'tensor_indices': self.tensor_indices
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user