Move more layer related utils to frozen_keras.
PiperOrigin-RevId: 299879819 Change-Id: Iebf2ea84be845bb09b926bdb3391375965e150aa
This commit is contained in:
parent
9a7288ccf5
commit
105c91f2c4
tensorflow/python/frozen_keras/engine
@ -12,6 +12,8 @@ py_library(
|
||||
srcs = ["legacy_base_layer.py"],
|
||||
deps = [
|
||||
":base_layer_utils",
|
||||
":input_spec",
|
||||
":node",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:auto_control_deps",
|
||||
@ -38,7 +40,6 @@ py_library(
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras:metrics",
|
||||
"//tensorflow/python/keras/engine",
|
||||
"//tensorflow/python/keras/engine:input_spec",
|
||||
"//tensorflow/python/keras/saving",
|
||||
"//tensorflow/python/keras/utils:generic_utils",
|
||||
"//tensorflow/python/keras/utils:layer_utils",
|
||||
@ -78,6 +79,31 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "input_spec",
|
||||
srcs = ["input_spec.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "node",
|
||||
srcs = ["node.py"],
|
||||
deps = [
|
||||
":base_layer_utils",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/keras:backend",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "legacy_base_layer_test",
|
||||
size = "medium",
|
||||
@ -110,3 +136,18 @@ tf_py_test(
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "input_spec_test",
|
||||
size = "small",
|
||||
srcs = ["input_spec_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||
],
|
||||
deps = [
|
||||
":input_spec",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
233
tensorflow/python/frozen_keras/engine/input_spec.py
Normal file
233
tensorflow/python/frozen_keras/engine/input_spec.py
Normal file
@ -0,0 +1,233 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=protected-access
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
"""Contains the InputSpec class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class InputSpec(object):
|
||||
"""Specifies the rank, dtype and shape of every input to a layer.
|
||||
|
||||
Layers can expose (if appropriate) an `input_spec` attribute:
|
||||
an instance of `InputSpec`, or a nested structure of `InputSpec` instances
|
||||
(one per input tensor). These objects enable the layer to run input
|
||||
compatibility checks for input structure, input rank, input shape, and
|
||||
input dtype.
|
||||
|
||||
A None entry in a shape is compatible with any dimension,
|
||||
a None shape is compatible with any shape.
|
||||
|
||||
Arguments:
|
||||
dtype: Expected DataType of the input.
|
||||
shape: Shape tuple, expected shape of the input
|
||||
(may include None for unchecked axes).
|
||||
ndim: Integer, expected rank of the input.
|
||||
max_ndim: Integer, maximum rank of the input.
|
||||
min_ndim: Integer, minimum rank of the input.
|
||||
axes: Dictionary mapping integer axes to
|
||||
a specific dimension value.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dtype=None,
|
||||
shape=None,
|
||||
ndim=None,
|
||||
max_ndim=None,
|
||||
min_ndim=None,
|
||||
axes=None):
|
||||
self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
|
||||
if shape is not None:
|
||||
self.ndim = len(shape)
|
||||
self.shape = shape
|
||||
else:
|
||||
self.ndim = ndim
|
||||
self.shape = None
|
||||
self.max_ndim = max_ndim
|
||||
self.min_ndim = min_ndim
|
||||
try:
|
||||
axes = axes or {}
|
||||
self.axes = {int(k): axes[k] for k in axes}
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError('The keys in axes must be integers.')
|
||||
|
||||
if self.axes and (self.ndim is not None or self.max_ndim is not None):
|
||||
max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
|
||||
max_axis = max(self.axes)
|
||||
if max_axis > max_dim:
|
||||
raise ValueError('Axis {} is greater than the maximum allowed value: {}'
|
||||
.format(max_axis, max_dim))
|
||||
|
||||
def __repr__(self):
|
||||
spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
|
||||
('shape=' + str(self.shape)) if self.shape else '',
|
||||
('ndim=' + str(self.ndim)) if self.ndim else '',
|
||||
('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
|
||||
('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
|
||||
('axes=' + str(self.axes)) if self.axes else '']
|
||||
return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
'dtype': self.dtype,
|
||||
'shape': self.shape,
|
||||
'ndim': self.ndim,
|
||||
'max_ndim': self.max_ndim,
|
||||
'min_ndim': self.min_ndim,
|
||||
'axes': self.axes}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
||||
|
||||
|
||||
def to_tensor_shape(spec):
|
||||
"""Returns a tf.TensorShape object that matches the shape specifications.
|
||||
|
||||
If the InputSpec's shape or ndim is defined, this method will return a fully
|
||||
or partially-known shape. Otherwise, the returned TensorShape is None.
|
||||
|
||||
Args:
|
||||
spec: an InputSpec object.
|
||||
|
||||
Returns:
|
||||
a tf.TensorShape object
|
||||
"""
|
||||
if spec.ndim is None and spec.shape is None:
|
||||
return tensor_shape.TensorShape(None)
|
||||
elif spec.shape is not None:
|
||||
return tensor_shape.TensorShape(spec.shape)
|
||||
else:
|
||||
shape = [None] * spec.ndim
|
||||
for a in spec.axes:
|
||||
shape[a] = spec.axes[a] # Assume that axes is defined
|
||||
return tensor_shape.TensorShape(shape)
|
||||
|
||||
|
||||
def assert_input_compatibility(input_spec, inputs, layer_name):
|
||||
"""Checks compatibility between the layer and provided inputs.
|
||||
|
||||
This checks that the tensor(s) `inputs` verify the input assumptions
|
||||
of a layer (if any). If not, a clear and actional exception gets raised.
|
||||
|
||||
Arguments:
|
||||
input_spec: An InputSpec instance, list of InputSpec instances, a nested
|
||||
structure of InputSpec instances, or None.
|
||||
inputs: Input tensor, list of input tensors, or a nested structure of
|
||||
input tensors.
|
||||
layer_name: String, name of the layer (for error message formatting).
|
||||
|
||||
Raises:
|
||||
ValueError: in case of mismatch between
|
||||
the provided inputs and the expectations of the layer.
|
||||
"""
|
||||
if not input_spec:
|
||||
return
|
||||
|
||||
inputs = nest.flatten(inputs)
|
||||
input_spec = nest.flatten(input_spec)
|
||||
if len(inputs) != len(input_spec):
|
||||
raise ValueError('Layer ' + layer_name + ' expects ' +
|
||||
str(len(input_spec)) + ' inputs, '
|
||||
'but it received ' + str(len(inputs)) +
|
||||
' input tensors. Inputs received: ' + str(inputs))
|
||||
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
|
||||
if spec is None:
|
||||
continue
|
||||
|
||||
if (spec.ndim is not None or
|
||||
spec.min_ndim is not None or
|
||||
spec.max_ndim is not None):
|
||||
if x.shape.ndims is None:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'its rank is undefined, but the layer requires a '
|
||||
'defined rank.')
|
||||
|
||||
# Check ndim.
|
||||
if spec.ndim is not None:
|
||||
ndim = x.shape.ndims
|
||||
if ndim != spec.ndim:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'expected ndim=' + str(spec.ndim) + ', found ndim=' +
|
||||
str(ndim) + '. Full shape received: ' +
|
||||
str(x.shape.as_list()))
|
||||
if spec.max_ndim is not None:
|
||||
ndim = x.shape.ndims
|
||||
if ndim is not None and ndim > spec.max_ndim:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'expected max_ndim=' + str(spec.max_ndim) +
|
||||
', found ndim=' + str(ndim))
|
||||
if spec.min_ndim is not None:
|
||||
ndim = x.shape.ndims
|
||||
if ndim is not None and ndim < spec.min_ndim:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
': expected min_ndim=' + str(spec.min_ndim) +
|
||||
', found ndim=' + str(ndim) +
|
||||
'. Full shape received: ' +
|
||||
str(x.shape.as_list()))
|
||||
# Check dtype.
|
||||
if spec.dtype is not None:
|
||||
if x.dtype != spec.dtype:
|
||||
raise ValueError('Input ' + str(input_index) + ' of layer ' +
|
||||
layer_name + ' is incompatible with the layer: '
|
||||
'expected dtype=' + str(spec.dtype) +
|
||||
', found dtype=' + str(x.dtype))
|
||||
# Check specific shape axes.
|
||||
if spec.axes:
|
||||
shape = x.shape.as_list()
|
||||
if shape is not None:
|
||||
for axis, value in spec.axes.items():
|
||||
if hasattr(value, 'value'):
|
||||
value = value.value
|
||||
if value is not None and shape[int(axis)] not in {value, None}:
|
||||
raise ValueError(
|
||||
'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
|
||||
' incompatible with the layer: expected axis ' + str(axis) +
|
||||
' of input shape to have value ' + str(value) +
|
||||
' but received input with shape ' + str(shape))
|
||||
# Check shape.
|
||||
if spec.shape is not None:
|
||||
shape = x.shape.as_list()
|
||||
if shape is not None:
|
||||
for spec_dim, dim in zip(spec.shape, shape):
|
||||
if spec_dim is not None and dim is not None:
|
||||
if spec_dim != dim:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + layer_name +
|
||||
': expected shape=' + str(spec.shape) +
|
||||
', found shape=' + str(shape))
|
||||
|
||||
|
||||
def to_tensor_spec(input_spec, default_dtype=None):
|
||||
"""Converts a Keras InputSpec object to a TensorSpec."""
|
||||
default_dtype = default_dtype or backend.floatx()
|
||||
if isinstance(input_spec, InputSpec):
|
||||
dtype = input_spec.dtype or default_dtype
|
||||
return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
|
||||
return tensor_spec.TensorSpec(None, default_dtype)
|
66
tensorflow/python/frozen_keras/engine/input_spec_test.py
Normal file
66
tensorflow/python/frozen_keras/engine/input_spec_test.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""InputSpec tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.frozen_keras.engine import input_spec
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class InputSpecTest(test.TestCase):
|
||||
|
||||
def test_axes_initialization(self):
|
||||
input_spec.InputSpec(shape=[1, None, 2, 3], axes={3: 5, '2': 2})
|
||||
with self.assertRaisesRegexp(ValueError, 'Axis 4 is greater than'):
|
||||
input_spec.InputSpec(shape=[1, None, 2, 3], axes={4: 5})
|
||||
with self.assertRaisesRegexp(TypeError, 'keys in axes must be integers'):
|
||||
input_spec.InputSpec(shape=[1, None, 2, 3], axes={'string': 5})
|
||||
|
||||
|
||||
class InputSpecToTensorShapeTest(test.TestCase):
|
||||
|
||||
def test_defined_shape(self):
|
||||
spec = input_spec.InputSpec(shape=[1, None, 2, 3])
|
||||
self.assertAllEqual(
|
||||
[1, None, 2, 3], input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
def test_defined_ndims(self):
|
||||
spec = input_spec.InputSpec(ndim=5)
|
||||
self.assertAllEqual(
|
||||
[None] * 5, input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
spec = input_spec.InputSpec(ndim=0)
|
||||
self.assertAllEqual(
|
||||
[], input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
spec = input_spec.InputSpec(ndim=3, axes={1: 3, -1: 2})
|
||||
self.assertAllEqual(
|
||||
[None, 3, 2], input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
def test_undefined_shapes(self):
|
||||
spec = input_spec.InputSpec(max_ndim=5)
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
|
||||
input_spec.to_tensor_shape(spec).as_list()
|
||||
|
||||
spec = input_spec.InputSpec(min_ndim=5, max_ndim=5)
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
|
||||
input_spec.to_tensor_shape(spec).as_list()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -54,10 +54,10 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.frozen_keras import constraints
|
||||
from tensorflow.python.frozen_keras import initializers
|
||||
from tensorflow.python.frozen_keras import regularizers
|
||||
from tensorflow.python.frozen_keras.engine import base_layer_utils
|
||||
from tensorflow.python.frozen_keras.engine import input_spec
|
||||
from tensorflow.python.frozen_keras.engine import node as node_module
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.engine import node as node_module
|
||||
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
|
189
tensorflow/python/frozen_keras/engine/node.py
Normal file
189
tensorflow/python/frozen_keras/engine/node.py
Normal file
@ -0,0 +1,189 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=protected-access
|
||||
"""Contains the `Node` class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.frozen_keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class Node(object):
|
||||
"""A `Node` describes the connectivity between two layers.
|
||||
|
||||
Each time a layer is connected to some new input,
|
||||
a node is added to `layer._inbound_nodes`.
|
||||
Each time the output of a layer is used by another layer,
|
||||
a node is added to `layer._outbound_nodes`.
|
||||
|
||||
Arguments:
|
||||
outbound_layer: the layer that takes
|
||||
`input_tensors` and turns them into `output_tensors`
|
||||
(the node gets created when the `call`
|
||||
method of the layer was called).
|
||||
inbound_layers: a list of layers, the same length as `input_tensors`,
|
||||
the layers from where `input_tensors` originate.
|
||||
node_indices: a list of integers, the same length as `inbound_layers`.
|
||||
`node_indices[i]` is the origin node of `input_tensors[i]`
|
||||
(necessary since each inbound layer might have several nodes,
|
||||
e.g. if the layer is being shared with a different data stream).
|
||||
tensor_indices: a list of integers,
|
||||
the same length as `inbound_layers`.
|
||||
`tensor_indices[i]` is the index of `input_tensors[i]` within the
|
||||
output of the inbound layer
|
||||
(necessary since each inbound layer might
|
||||
have multiple tensor outputs, with each one being
|
||||
independently manipulable).
|
||||
input_tensors: list of input tensors.
|
||||
output_tensors: list of output tensors.
|
||||
arguments: dictionary of keyword arguments that were passed to the
|
||||
`call` method of the layer at the call that created the node.
|
||||
|
||||
`node_indices` and `tensor_indices` are basically fine-grained coordinates
|
||||
describing the origin of the `input_tensors`.
|
||||
|
||||
A node from layer A to layer B is added to:
|
||||
- A._outbound_nodes
|
||||
- B._inbound_nodes
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
outbound_layer,
|
||||
inbound_layers,
|
||||
node_indices,
|
||||
tensor_indices,
|
||||
input_tensors,
|
||||
output_tensors,
|
||||
arguments=None):
|
||||
# Layer instance (NOT a sequence)
|
||||
if isinstance(outbound_layer, (list, tuple, dict)):
|
||||
raise ValueError('`outbound_layer` should be a layer instance, '
|
||||
'not a list, tuple, or, dict.')
|
||||
|
||||
# this is the layer that takes a nested structure of input tensors
|
||||
# and turns them into a nested structure of output tensors.
|
||||
# the current node will be added to
|
||||
# the inbound_nodes of outbound_layer.
|
||||
self.outbound_layer = outbound_layer
|
||||
|
||||
# The following 3 properties describe where
|
||||
# the input tensors come from: which layers,
|
||||
# and for each layer, which node and which
|
||||
# tensor output of each node.
|
||||
|
||||
# Nested structure of layer instances.
|
||||
self.inbound_layers = inbound_layers
|
||||
# Nested structure of integers, 1:1 mapping with inbound_layers.
|
||||
self.node_indices = node_indices
|
||||
# Nested of integers, 1:1 mapping with inbound_layers.
|
||||
self.tensor_indices = tensor_indices
|
||||
|
||||
# Following 2 properties:
|
||||
# tensor inputs and outputs of outbound_layer.
|
||||
|
||||
# Nested structure of tensors. 1:1 mapping with inbound_layers.
|
||||
self.input_tensors = input_tensors
|
||||
# Nested structure of tensors, created by outbound_layer.call().
|
||||
self.output_tensors = output_tensors
|
||||
|
||||
# Following 2 properties: input and output shapes.
|
||||
|
||||
# Nested structure of shape tuples, shapes of input_tensors.
|
||||
self.input_shapes = nest.map_structure(backend.int_shape, input_tensors)
|
||||
# Nested structure of shape tuples, shapes of output_tensors.
|
||||
self.output_shapes = nest.map_structure(backend.int_shape, output_tensors)
|
||||
|
||||
# Optional keyword arguments to layer's `call`.
|
||||
self.arguments = arguments
|
||||
|
||||
# Create Keras History for any Keras Tensors in `arguments`.
|
||||
tensor_arguments = [
|
||||
t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
|
||||
]
|
||||
for tensor_argument in tensor_arguments:
|
||||
if base_layer_utils.needs_keras_history(
|
||||
tensor_argument, ignore_call_context=True):
|
||||
base_layer_utils.create_keras_history(tensor_argument)
|
||||
|
||||
# Add nodes to all layers involved.
|
||||
for layer in nest.flatten(inbound_layers):
|
||||
if layer is not None:
|
||||
# For compatibility with external Keras, we use the deprecated
|
||||
# accessor here.
|
||||
layer.outbound_nodes.append(self)
|
||||
# For compatibility with external Keras, we use the deprecated
|
||||
# accessor here.
|
||||
outbound_layer.inbound_nodes.append(self)
|
||||
|
||||
def iterate_inbound(self, include_arguments=False):
|
||||
"""Returns a list of tuples representing the inbound data.
|
||||
|
||||
Arguments:
|
||||
include_arguments: Whether to also iterate over any Keras Tensors
|
||||
passed as args, kwargs.
|
||||
|
||||
Returns:
|
||||
List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
|
||||
"""
|
||||
inputs_inbound = list(
|
||||
zip(
|
||||
nest.flatten(self.inbound_layers),
|
||||
nest.flatten(self.node_indices),
|
||||
nest.flatten(self.tensor_indices),
|
||||
nest.flatten(self.input_tensors)))
|
||||
|
||||
if include_arguments:
|
||||
keras_tensor_arguments = [
|
||||
kt for kt in nest.flatten(self.arguments)
|
||||
if hasattr(kt, '_keras_history')
|
||||
]
|
||||
|
||||
def _get_inbound(keras_tensor):
|
||||
kh = keras_tensor._keras_history
|
||||
return kh.layer, kh.node_index, kh.tensor_index, keras_tensor
|
||||
|
||||
arguments_inbound = nest.map_structure(_get_inbound,
|
||||
keras_tensor_arguments)
|
||||
|
||||
return inputs_inbound + arguments_inbound
|
||||
else:
|
||||
return inputs_inbound
|
||||
|
||||
def _get_all_node_dependencies(self):
|
||||
"""Returns all of the nodes this node immediately depends on."""
|
||||
node_deps = []
|
||||
for layer, node_index, _, _ in self.iterate_inbound():
|
||||
node_deps.append(layer._inbound_nodes[node_index])
|
||||
|
||||
for arg in nest.flatten(self.arguments):
|
||||
if isinstance(arg, ops.Tensor) and hasattr(arg, '_keras_history'):
|
||||
kh = arg._keras_history
|
||||
node_deps.append(kh.layer._inbound_nodes[kh.node_index])
|
||||
|
||||
return node_deps
|
||||
|
||||
def get_config(self):
|
||||
inbound_names = nest.map_structure(
|
||||
lambda layer: layer.name if layer else None, self.inbound_layers)
|
||||
return {
|
||||
'outbound_layer': self.outbound_layer.name,
|
||||
'inbound_layers': inbound_names,
|
||||
'node_indices': self.node_indices,
|
||||
'tensor_indices': self.tensor_indices
|
||||
}
|
Loading…
Reference in New Issue
Block a user