Remove all frozen copies of Keras code.
PiperOrigin-RevId: 300795615
Change-Id: Ibe8e69cddeb992aaa00a87da4c6543c8804f7b14
parent 4fcd935d48
commit 9d297ebabc
tensorflow/python/frozen_keras/BUILD
@@ -1,175 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
    default_visibility = ["//tensorflow:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)

# TODO(scottzhu): Clean up all the deps to python/keras

py_library(
    name = "frozen_keras",
    deps = [
        ":backend",
        ":backend_config",
        ":constraint",
        ":initializers",
        ":regularizers",
        "//tensorflow/python/frozen_keras/engine:legacy_base_layer",
    ],
)

py_library(
    name = "activations",
    srcs = ["activations.py"],
    deps = [
        ":backend",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python/frozen_keras/utils:generic_utils",
        "@six_archive//:six",
    ],
)

py_library(
    name = "backend",
    srcs = ["backend.py"],
    deps = [
        ":backend_config",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:clip_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:config",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:control_flow_util",
        "//tensorflow/python:ctc_ops",
        "//tensorflow/python:device",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:image_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:linalg_ops",
        "//tensorflow/python:logging_ops",
        "//tensorflow/python:map_fn",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:platform",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:tensor_array_grad",
        "//tensorflow/python:tensor_array_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:tf2",
        "//tensorflow/python:training_lib",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_coordinator",
        "//tensorflow/python/distribute:distribute_coordinator_context",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/eager:lift_to_graph",
        "//tensorflow/python/ops/ragged:ragged_concat_ops",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "backend_config",
    srcs = ["backend_config.py"],
    deps = [],
)

py_library(
    name = "constraint",
    srcs = ["constraints.py"],
    deps = [
        ":backend",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/frozen_keras/utils:generic_utils",
        "@six_archive//:six",
    ],
)

py_library(
    name = "initializers",
    srcs = ["initializers.py"],
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:tf2",
        "//tensorflow/python/frozen_keras/utils:generic_utils",
        "@six_archive//:six",
    ],
)

py_library(
    name = "regularizers",
    srcs = ["regularizers.py"],
    deps = [
        ":backend",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/frozen_keras/utils:generic_utils",
        "@six_archive//:six",
    ],
)

tf_py_test(
    name = "backend_test",
    size = "medium",
    srcs = ["backend_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":backend",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:config",
        "//tensorflow/python:errors",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:nn",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/frozen_keras/engine:base_layer_utils",
        "//tensorflow/python/keras:combinations",
        "//tensorflow/python/keras/engine",
        "//tensorflow/python/keras/layers:advanced_activations",
        "//tensorflow/python/keras/layers:normalization",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "backend_config_test",
    size = "medium",
    srcs = ["backend_config_test.py"],
    python_version = "PY3",
    deps = [
        ":backend",
        ":backend_config",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras:combinations",
    ],
)
tensorflow/python/frozen_keras/README.md
@@ -1,15 +0,0 @@
# DO NOT USE

Everything under this package is for internal usage only, and serves as a
dependency for legacy TF v1 APIs that rely on Keras. Any active development
should happen in third_party/tensorflow/python/keras instead.

## Background

In order to build a more modular TensorFlow and Keras, we decided to split the
Keras code into its own repository. Having TensorFlow depend on Keras is a red
flag, as it is a reverse dependency. Because some legacy TF v1 APIs use Keras
classes as base classes (such as `Layer`), we decided to keep a trimmed copy of
the Keras code to resolve the reverse dependency. This also ensures that the
stability of the TF v1 API is not affected by the active development of the
Keras project.
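To illustrate the pattern this package supported before its removal, here is a
minimal, hypothetical sketch: a legacy v1-style layer inherits from the frozen
base layer instead of the actively developed Keras one. The subclass name and
body are invented, and the `LegacyBaseLayer` class name is assumed from this
package's `legacy_base_layer` target.

```python
# Hypothetical sketch of the reverse-dependency pattern described above: a
# legacy TF v1 layer builds on the frozen copy of the Keras base layer, so
# active development in the live Keras project cannot break it.
from tensorflow.python.frozen_keras.engine import legacy_base_layer


class MyLegacyV1Layer(legacy_base_layer.LegacyBaseLayer):  # assumed class name
  """Invented example; real v1 APIs followed this shape."""

  def call(self, inputs):
    # Identity transform; a real layer would compute something here.
    return inputs
```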
tensorflow/python/frozen_keras/activations.py
@@ -1,453 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in activation functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn

# b/123041942
# In TF 2.x, if `tf.nn.softmax` is used as an activation function in Keras
# layers, it gets serialized as 'softmax_v2' instead of 'softmax', as the
# internal method name is returned in serialization. This results in errors in
# model exporting and loading, as Keras can't find any activation function
# with the name `softmax_v2`.

# This dict maps the activation function name from its v2 version to its
# canonical name.
_TF_ACTIVATIONS_V2 = {
    'softmax_v2': 'softmax',
}


def softmax(x, axis=-1):
  """Softmax converts a real vector to a vector of categorical probabilities.

  The elements of the output vector are in range (0, 1) and sum to 1.

  Each vector is handled independently. The `axis` argument sets which axis
  of the input the function is applied along.

  Softmax is often used as the activation for the last
  layer of a classification network because the result could be interpreted as
  a probability distribution.

  The softmax of each vector x is calculated by `exp(x)/tf.reduce_sum(exp(x))`.
  The input values are the log-odds of the resulting probability.

  Arguments:
    x: Input tensor.
    axis: Integer, axis along which the softmax normalization is applied.

  Returns:
    Tensor, output of softmax transformation (all values are non-negative
    and sum to 1).

  Raises:
    ValueError: In case `dim(x) == 1`.
  """
  ndim = K.ndim(x)
  if ndim == 2:
    return nn.softmax(x)
  elif ndim > 2:
    e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
    s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
    return e / s
  else:
    raise ValueError('Cannot apply softmax to a tensor that is 1D. '
                     'Received input: %s' % (x,))


def elu(x, alpha=1.0):
  """Exponential linear unit.

  Arguments:
    x: Input tensor.
    alpha: A scalar, slope of negative section.

  Returns:
    The exponential linear activation: `x` if `x > 0` and
    `alpha * (exp(x) - 1)` if `x < 0`.

  Reference:
    - [Fast and Accurate Deep Network Learning by Exponential
      Linear Units (ELUs)](https://arxiv.org/abs/1511.07289)
  """
  return K.elu(x, alpha)


def selu(x):
  """Scaled Exponential Linear Unit (SELU).

  The Scaled Exponential Linear Unit (SELU) activation function is:
  `scale * x` if `x > 0` and `scale * alpha * (exp(x) - 1)` if `x < 0`,
  where `alpha` and `scale` are pre-defined constants
  (`alpha = 1.67326324` and `scale = 1.05070098`).
  The SELU activation function multiplies `scale` > 1 with the
  [elu](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/activations/elu)
  (Exponential Linear Unit) to ensure a slope larger than one
  for positive net inputs.

  The values of `alpha` and `scale` are
  chosen so that the mean and variance of the inputs are preserved
  between two consecutive layers as long as the weights are initialized
  correctly (see [`lecun_normal` initialization]
  (https://www.tensorflow.org/api_docs/python/tf/keras/initializers/lecun_normal))
  and the number of inputs is "large enough"
  (see references for more information).

  ![](https://cdn-images-1.medium.com/max/1600/1*m0e8lZU_Zrkh4ESfQkY2Pw.png)
  (Courtesy: Blog on Towards Data Science at
  https://towardsdatascience.com/selu-make-fnns-great-again-snn-8d61526802a9)

  Example Usage:

  >>> n_classes = 10  # 10-class problem
  >>> from tensorflow.python.keras.layers import Dense
  >>> model = tf.keras.Sequential()
  >>> model.add(Dense(64, kernel_initializer='lecun_normal',
  ...                 activation='selu', input_shape=(28, 28, 1)))
  >>> model.add(Dense(32, kernel_initializer='lecun_normal',
  ...                 activation='selu'))
  >>> model.add(Dense(16, kernel_initializer='lecun_normal',
  ...                 activation='selu'))
  >>> model.add(Dense(n_classes, activation='softmax'))

  Arguments:
    x: A tensor or variable to compute the activation function for.

  Returns:
    The scaled exponential unit activation: `scale * elu(x, alpha)`.

  # Note
    - To be used together with the initialization "[lecun_normal]
      (https://www.tensorflow.org/api_docs/python/tf/keras/initializers/lecun_normal)".
    - To be used together with the dropout variant "[AlphaDropout]
      (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AlphaDropout)".

  References:
    [Self-Normalizing Neural Networks (Klambauer et al, 2017)]
    (https://arxiv.org/abs/1706.02515)
  """
  return nn.selu(x)


def softplus(x):
  """Softplus activation function.

  Arguments:
    x: Input tensor.

  Returns:
    The softplus activation: `log(exp(x) + 1)`.
  """
  return nn.softplus(x)


def softsign(x):
  """Softsign activation function.

  Arguments:
    x: Input tensor.

  Returns:
    The softsign activation: `x / (abs(x) + 1)`.
  """
  return nn.softsign(x)


def swish(x):
  """Swish activation function.

  Arguments:
    x: Input tensor.

  Returns:
    The swish activation applied to `x`.
  """
  return nn.swish(x)


def relu(x, alpha=0., max_value=None, threshold=0):
  """Applies the rectified linear unit activation function.

  With default values, this returns the standard ReLU activation:
  `max(x, 0)`, the element-wise maximum of 0 and the input tensor.

  Modifying default parameters allows you to use non-zero thresholds,
  change the max value of the activation,
  and to use a non-zero multiple of the input for values below the threshold.

  For example:

  >>> foo = tf.constant([-10, -5, 0.0, 5, 10], dtype = tf.float32)
  >>> tf.keras.activations.relu(foo).numpy()
  array([ 0.,  0.,  0.,  5., 10.], dtype=float32)
  >>> tf.keras.activations.relu(foo, alpha=0.5).numpy()
  array([-5. , -2.5,  0. ,  5. , 10. ], dtype=float32)
  >>> tf.keras.activations.relu(foo, max_value=5).numpy()
  array([0., 0., 0., 5., 5.], dtype=float32)
  >>> tf.keras.activations.relu(foo, threshold=5).numpy()
  array([-0., -0.,  0.,  0., 10.], dtype=float32)

  Arguments:
    x: Input `tensor` or `variable`.
    alpha: A `float` that governs the slope for values lower than the
      threshold.
    max_value: A `float` that sets the saturation threshold (the largest value
      the function will return).
    threshold: A `float` giving the threshold value of the activation function
      below which values will be damped or set to zero.

  Returns:
    A `Tensor` representing the input tensor,
    transformed by the relu activation function.
    Tensor will be of the same shape and dtype of input `x`.
  """
  return K.relu(x, alpha=alpha, max_value=max_value, threshold=threshold)


def tanh(x):
  """Hyperbolic tangent activation function.

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.tanh(a)
  >>> b.numpy()
  array([-0.9950547, -0.7615942,  0.       ,  0.7615942,  0.9950547],
        dtype=float32)

  Arguments:
    x: Input tensor.

  Returns:
    Tensor of same shape and dtype of input `x`, with tanh activation:
    `tanh(x) = sinh(x)/cosh(x) = ((exp(x) - exp(-x))/(exp(x) + exp(-x)))`.
  """
  return nn.tanh(x)


def sigmoid(x):
  """Sigmoid activation function.

  Applies the sigmoid activation function. The sigmoid function is defined as
  1 divided by (1 + exp(-x)). Its curve is like an "S", a smoothed version of
  the Heaviside (unit step) function. For small values (< -5) the sigmoid
  returns a value close to zero, and for larger values (> 5) the result of the
  function gets close to 1.

  Sigmoid is equivalent to a 2-element Softmax, where the second element is
  assumed to be zero.

  For example:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
  >>> b = tf.keras.activations.sigmoid(a)
  >>> b.numpy() >= 0.0
  array([ True,  True,  True,  True,  True])

  Arguments:
    x: Input tensor.

  Returns:
    Tensor with the sigmoid activation: `(1.0 / (1.0 + exp(-x)))`.
    Tensor will be of same shape and dtype of input `x`.
  """
  return nn.sigmoid(x)


def exponential(x):
  """Exponential activation function.

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.exponential(a)
  >>> b.numpy()
  array([ 0.04978707,  0.36787945,  1.        ,  2.7182817 , 20.085537  ],
        dtype=float32)

  Arguments:
    x: Input tensor.

  Returns:
    Tensor with exponential activation: `exp(x)`. Tensor will be of same
    shape and dtype of input `x`.
  """
  return math_ops.exp(x)


def hard_sigmoid(x):
  """Hard sigmoid activation function.

  Faster to compute than the sigmoid activation.

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.hard_sigmoid(a)
  >>> b.numpy()
  array([0. , 0.3, 0.5, 0.7, 1. ], dtype=float32)

  Arguments:
    x: Input tensor.

  Returns:
    The hard sigmoid activation:

    - `0` if `x < -2.5`
    - `1` if `x > 2.5`
    - `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`.
  """
  return K.hard_sigmoid(x)


def linear(x):
  """Linear activation function.

  For example:

  >>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
  >>> b = tf.keras.activations.linear(a)
  >>> b.numpy()
  array([-3., -1.,  0.,  1.,  3.], dtype=float32)

  Arguments:
    x: Input tensor.

  Returns:
    The input, unmodified.
  """
  return x


def serialize(activation):
  """Returns the name attribute (`__name__`) of a function.

  Arguments:
    activation: Function.

  Returns:
    String denoting the name attribute of the input function.

  For example:

  >>> tf.keras.activations.serialize(tf.keras.activations.tanh)
  'tanh'
  >>> tf.keras.activations.serialize(tf.keras.activations.sigmoid)
  'sigmoid'
  >>> tf.keras.activations.serialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: ('Cannot serialize', 'abcd')

  Raises:
    ValueError: The input function is not a valid one.
  """
  if (hasattr(activation, '__name__') and
      activation.__name__ in _TF_ACTIVATIONS_V2):
    return _TF_ACTIVATIONS_V2[activation.__name__]
  return serialize_keras_object(activation)


def deserialize(name, custom_objects=None):
  """Returns the activation function denoted by the input string.

  Arguments:
    name: The name of the activation function.
    custom_objects: A {name: value} dictionary for activations not built into
      Keras.

  Returns:
    TensorFlow activation function denoted by the input string.

  For example:

  >>> tf.keras.activations.deserialize('linear')
  <function linear at 0x1239596a8>
  >>> tf.keras.activations.deserialize('sigmoid')
  <function sigmoid at 0x123959510>
  >>> tf.keras.activations.deserialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: `Unknown activation function` if the input string does not
      denote any defined TensorFlow activation function.
  """
  return deserialize_keras_object(
      name,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='activation function')


def get(identifier):
  """Returns an activation function.

  Arguments:
    identifier: Function or string.

  Returns:
    Activation function denoted by the input:
    - Linear activation function if the input is `None`.
    - Function corresponding to the input string or input function.

  For example:

  >>> tf.keras.activations.get('softmax')
  <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(tf.keras.activations.softmax)
  <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(None)
  <function linear at 0x1239596a8>
  >>> tf.keras.activations.get(abs)
  <built-in function abs>
  >>> tf.keras.activations.get('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
    ValueError: Input is an unknown function or string, i.e., the input does
      not denote any defined function.
  """
  if identifier is None:
    return linear
  if isinstance(identifier, six.string_types):
    identifier = str(identifier)
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  elif isinstance(identifier, dict):
    return deserialize_keras_object(
        identifier, printable_module_name='activation')
  else:
    raise TypeError(
        'Could not interpret activation function identifier: {}'.format(
            repr(identifier)))
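The `_TF_ACTIVATIONS_V2` canonical-name mapping above is easiest to see in a
short round trip: serializing the TF2 softmax (whose internal `__name__` is
`softmax_v2`) yields the canonical name, which can then be looked up again. A
minimal sketch, assuming a TF 2.x installation and using only the public
`tf.keras.activations` API:

```python
# Sketch of the softmax_v2 canonicalization handled by serialize() above.
import tensorflow as tf

name = tf.keras.activations.serialize(tf.nn.softmax)
print(name)  # 'softmax', not 'softmax_v2', thanks to the canonical-name map.

fn = tf.keras.activations.get(name)  # Round-trips back to a callable.
print(fn(tf.constant([[1.0, 2.0, 3.0]])).numpy())
```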
File diff suppressed because it is too large
tensorflow/python/frozen_keras/backend_config.py
@@ -1,140 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras backend config API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


# The type of float to use throughout a session.
_FLOATX = 'float32'

# Epsilon fuzz factor used throughout the codebase.
_EPSILON = 1e-7

# Default image data format, one of "channels_last", "channels_first".
_IMAGE_DATA_FORMAT = 'channels_last'


def epsilon():
  """Returns the value of the fuzz factor used in numeric expressions.

  Returns:
    A float.

  Example:
  >>> tf.keras.backend.epsilon()
  1e-07
  """
  return _EPSILON


def set_epsilon(value):
  """Sets the value of the fuzz factor used in numeric expressions.

  Arguments:
    value: float. New value of epsilon.

  Example:
  >>> tf.keras.backend.epsilon()
  1e-07
  >>> tf.keras.backend.set_epsilon(1e-5)
  >>> tf.keras.backend.epsilon()
  1e-05
  >>> tf.keras.backend.set_epsilon(1e-7)
  """
  global _EPSILON
  _EPSILON = value


def floatx():
  """Returns the default float type, as a string.

  E.g. `'float16'`, `'float32'`, `'float64'`.

  Returns:
    String, the current default float type.

  Example:
  >>> tf.keras.backend.floatx()
  'float32'
  """
  return _FLOATX


def set_floatx(value):
  """Sets the default float type.

  Note: It is not recommended to set this to float16 for training, as this
  will likely cause numeric stability issues. Instead, mixed precision, which
  uses a mix of float16 and float32, can be enabled by calling
  `tf.keras.mixed_precision.experimental.set_policy('mixed_float16')`. See the
  [mixed precision guide](https://www.tensorflow.org/guide/keras/mixed_precision)
  for details.

  Arguments:
    value: String; `'float16'`, `'float32'`, or `'float64'`.

  Example:
  >>> tf.keras.backend.floatx()
  'float32'
  >>> tf.keras.backend.set_floatx('float64')
  >>> tf.keras.backend.floatx()
  'float64'
  >>> tf.keras.backend.set_floatx('float32')

  Raises:
    ValueError: In case of invalid value.
  """
  global _FLOATX
  if value not in {'float16', 'float32', 'float64'}:
    raise ValueError('Unknown floatx type: ' + str(value))
  _FLOATX = str(value)


def image_data_format():
  """Returns the default image data format convention.

  Returns:
    A string, either `'channels_first'` or `'channels_last'`.

  Example:
  >>> tf.keras.backend.image_data_format()
  'channels_last'
  """
  return _IMAGE_DATA_FORMAT


def set_image_data_format(data_format):
  """Sets the value of the image data format convention.

  Arguments:
    data_format: string. `'channels_first'` or `'channels_last'`.

  Example:
  >>> tf.keras.backend.image_data_format()
  'channels_last'
  >>> tf.keras.backend.set_image_data_format('channels_first')
  >>> tf.keras.backend.image_data_format()
  'channels_first'
  >>> tf.keras.backend.set_image_data_format('channels_last')

  Raises:
    ValueError: In case of invalid `data_format` value.
  """
  global _IMAGE_DATA_FORMAT
  if data_format not in {'channels_last', 'channels_first'}:
    raise ValueError('Unknown data_format: ' + str(data_format))
  _IMAGE_DATA_FORMAT = str(data_format)
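The epsilon "fuzz factor" above exists to keep divisions and logarithms
numerically safe; a minimal sketch of the usual pattern, with illustrative
values and using only the public `tf.keras.backend` API:

```python
# Sketch: guard a weight normalization against division by zero using the
# backend epsilon. The tensor values here are illustrative.
import tensorflow as tf

w = tf.constant([[3.0, 0.0], [4.0, 0.0]])
norms = tf.sqrt(tf.reduce_sum(tf.square(w), axis=0, keepdims=True))
unit = w / (tf.keras.backend.epsilon() + norms)  # zero column stays finite
print(unit.numpy())
```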
tensorflow/python/frozen_keras/backend_config_test.py
@@ -1,55 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for backend_config."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.frozen_keras import backend
from tensorflow.python.frozen_keras import backend_config
from tensorflow.python.keras import combinations
from tensorflow.python.platform import test


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendConfigTest(test.TestCase):

  def test_backend(self):
    self.assertEqual(backend.backend(), 'tensorflow')

  def test_epsilon(self):
    epsilon = 1e-2
    backend_config.set_epsilon(epsilon)
    self.assertEqual(backend_config.epsilon(), epsilon)
    backend_config.set_epsilon(1e-7)
    self.assertEqual(backend_config.epsilon(), 1e-7)

  def test_floatx(self):
    floatx = 'float64'
    backend_config.set_floatx(floatx)
    self.assertEqual(backend_config.floatx(), floatx)
    backend_config.set_floatx('float32')
    self.assertEqual(backend_config.floatx(), 'float32')

  def test_image_data_format(self):
    image_data_format = 'channels_first'
    backend_config.set_image_data_format(image_data_format)
    self.assertEqual(backend_config.image_data_format(), image_data_format)
    backend_config.set_image_data_format('channels_last')
    self.assertEqual(backend_config.image_data_format(), 'channels_last')


if __name__ == '__main__':
  test.main()
File diff suppressed because it is too large
tensorflow/python/frozen_keras/constraints.py
@@ -1,282 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-classes-have-attributes
"""Constraints: functions that impose constraints on weight values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import tensor_shape
from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


class Constraint(object):

  def __call__(self, w):
    return w

  def get_config(self):
    return {}


class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Arguments:
    max_value: the maximum norm for the incoming weights.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`;
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`;
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  def __call__(self, w):
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    desired = K.clip(norms, 0, self.max_value)
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {'max_value': self.max_value, 'axis': self.axis}


class NonNeg(Constraint):
  """Constrains the weights to be non-negative."""

  def __call__(self, w):
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())


class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Arguments:
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`;
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`;
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  def __call__(self, w):
    return w / (
        K.epsilon() + K.sqrt(
            math_ops.reduce_sum(
                math_ops.square(w), axis=self.axis, keepdims=True)))

  def get_config(self):
    return {'axis': self.axis}


class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm between a lower bound and an upper bound.

  Arguments:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`;
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`;
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  def __call__(self, w):
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    desired = (
        self.rate * K.clip(norms, self.min_value, self.max_value) +
        (1 - self.rate) * norms)
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis
    }


class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  For example, the desired output for the following 4-by-4 kernel:

  ```
  kernel = [[v_00, v_01, v_02, v_03],
            [v_10, v_11, v_12, v_13],
            [v_20, v_21, v_22, v_23],
            [v_30, v_31, v_32, v_33]]
  ```

  is this:

  ```
  kernel = [[v_11, v_11, v_11, v_11],
            [v_11, v_33, v_33, v_11],
            [v_11, v_33, v_33, v_11],
            [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"`
  or `"channels_first"` data format. The method assumes the weight tensor is
  of shape `(rows, cols, input_depth, output_depth)`.
  """

  def __call__(self, w):
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    w = K.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
    # is supported.
    w = K.map_fn(
        self._kernel_constraint,
        K.stack(array_ops.unstack(w, axis=-1), axis=0))
    return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
                     (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains a kernel with shape (height, width, channels)."""
    padding = K.constant([[1, 1], [1, 1]], dtype='int32')

    kernel_shape = K.shape(kernel)[0]
    start = K.cast(kernel_shape / 2, 'int32')

    kernel_new = K.switch(
        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: kernel[start - 1:start, start - 1:start],
        lambda: kernel[start - 1:start, start - 1:start] + K.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    index = K.switch(
        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: K.constant(0, dtype='int32'),
        lambda: K.constant(1, dtype='int32'))
    while_condition = lambda index, *args: K.less(index, start)

    def body_fn(i, array):
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    _, kernel_new = control_flow_ops.while_loop(
        while_condition,
        body_fn,
        [index, kernel_new],
        shape_invariants=[index.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new


# Aliases.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm


def serialize(constraint):
  return serialize_keras_object(constraint)


def deserialize(config, custom_objects=None):
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='constraint')


def get(identifier):
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret constraint identifier: ' +
                     str(identifier))
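For context, a minimal usage sketch (with illustrative tensor values, using
the public `tf.keras.constraints` counterpart of the code above) showing how
`MaxNorm` rescales a weight tensor; constraints are typically attached via a
layer's `kernel_constraint` argument:

```python
# Sketch: applying MaxNorm directly to a weight tensor, and attaching it to
# a layer so it runs after each weight update.
import tensorflow as tf

constraint = tf.keras.constraints.MaxNorm(max_value=2, axis=0)
w = tf.constant([[3.0, 1.0], [4.0, 1.0]])  # column norms: 5.0 and ~1.41
print(constraint(w).numpy())  # the first column is rescaled to norm 2

layer = tf.keras.layers.Dense(4, kernel_constraint=constraint)
```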
tensorflow/python/frozen_keras/engine/BUILD
@@ -1,151 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
    default_visibility = ["//tensorflow:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)

# TODO(scottzhu): Clean up all the deps to python/keras

py_library(
    name = "legacy_base_layer",
    srcs = ["legacy_base_layer.py"],
    deps = [
        ":base_layer_utils",
        ":input_spec",
        ":node",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:auto_control_deps",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/autograph/core",
        "//tensorflow/python/autograph/impl",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:execute",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/frozen_keras:backend",
        "//tensorflow/python/frozen_keras:constraint",
        "//tensorflow/python/frozen_keras:initializers",
        "//tensorflow/python/frozen_keras:regularizers",
        "//tensorflow/python/frozen_keras/utils:generic_utils",
        "//tensorflow/python/frozen_keras/utils:layer_utils",
        "//tensorflow/python/frozen_keras/utils:tf_utils",
        "//tensorflow/python/keras:metrics",
        "//tensorflow/python/module",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/training/tracking",
        "//tensorflow/python/training/tracking:base",
        "//tensorflow/python/training/tracking:data_structures",
        "//tensorflow/python/training/tracking:layer_utils",
        "//tensorflow/tools/docs:doc_controls",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

py_library(
    name = "base_layer_utils",
    srcs = ["base_layer_utils.py"],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_util",
        "//tensorflow/python:control_flow_v2_func_graphs",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:tf2",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/frozen_keras:backend",
        "//tensorflow/python/training/tracking:base",
    ],
)

py_library(
    name = "input_spec",
    srcs = ["input_spec.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:lib",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python/frozen_keras:backend",
        "@six_archive//:six",
    ],
)

py_library(
    name = "node",
    srcs = ["node.py"],
    deps = [
        ":base_layer_utils",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/frozen_keras:backend",
    ],
)

tf_py_test(
    name = "legacy_base_layer_test",
    size = "medium",
    srcs = ["legacy_base_layer_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
        "no_rocm",
        "nomac",  # TODO(mihaimaruseac): b/127695564
    ],
    deps = [
        ":legacy_base_layer",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "base_layer_utils_test",
    srcs = ["base_layer_utils_test.py"],
    python_version = "PY3",
    tags = [
        "nomac",  # TODO(mihaimaruseac): b/127695564
    ],
    deps = [
        ":base_layer_utils",
        "//tensorflow/python:client_testlib",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "input_spec_test",
    size = "small",
    srcs = ["input_spec_test.py"],
    python_version = "PY3",
    tags = [
        "nomac",  # TODO(mihaimaruseac): b/127695564
    ],
    deps = [
        ":input_spec",
        "//tensorflow/python:client_testlib",
        "@absl_py//absl/testing:parameterized",
    ],
)
@ -1,781 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains private utilities used mainly by the base Layer class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.frozen_keras import backend
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import control_flow_v2_func_graphs
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.training.tracking import base as tracking
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
_call_context = threading.local()
|
||||
|
||||
|
||||
def make_variable(name,
|
||||
shape=None,
|
||||
dtype=dtypes.float32,
|
||||
initializer=None,
|
||||
trainable=None,
|
||||
caching_device=None,
|
||||
validate_shape=True,
|
||||
constraint=None,
|
||||
use_resource=None,
|
||||
collections=None,
|
||||
synchronization=tf_variables.VariableSynchronization.AUTO,
|
||||
aggregation=tf_variables.VariableAggregation.NONE,
|
||||
partitioner=None): # pylint: disable=unused-argument
|
||||
"""Temporary util to create a variable (relies on `variable_scope.variable`).
|
||||
|
||||
Some reuse-related technicalities prevent us from using
|
||||
`variable_scope.get_variable()` directly, so we use a subcomponent
|
||||
that has fewer constraints (`variable_scope.variable()`).
|
||||
|
||||
In the longer term, it seems like a similar "default variable creator" method
|
||||
should exist in `Trackable` instead. When this happens, we can get
|
||||
rid of this temporary solution.
|
||||
|
||||
TODO(fchollet): remove this method when no longer needed.
|
||||
|
||||
Arguments:
|
||||
name: Variable name.
|
||||
shape: Variable shape.
|
||||
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
|
||||
initializer: Initializer instance (callable).
|
||||
trainable: Whether the variable should be part of the layer's
|
||||
"trainable_variables" (e.g. variables, biases)
|
||||
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
|
||||
Note, if the current variable scope is marked as non-trainable
|
||||
then this parameter is ignored and any added variables are also
|
||||
marked as non-trainable. `trainable` defaults to `True` unless
|
||||
`synchronization` is set to `ON_READ`.
|
||||
caching_device: Passed to `tf.Variable`.
|
||||
validate_shape: Passed to `tf.Variable`.
|
||||
constraint: Constraint instance (callable).
|
||||
use_resource: Whether to use a `ResourceVariable`.
|
||||
collections: List of graph collections keys. The new variable is added to
|
||||
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
|
||||
synchronization: Indicates when a distributed a variable will be
|
||||
aggregated. Accepted values are constants defined in the class
|
||||
`tf.VariableSynchronization`. By default the synchronization is set to
|
||||
`AUTO` and the current `DistributionStrategy` chooses
|
||||
when to synchronize. If `synchronization` is set to `ON_READ`,
|
||||
`trainable` must not be set to `True`.
|
||||
aggregation: Indicates how a distributed variable will be aggregated.
|
||||
Accepted values are constants defined in the class
|
||||
`tf.VariableAggregation`.
|
||||
partitioner: Not handled at this time.
|
||||
|
||||
Returns:
|
||||
Variable instance.
|
||||
"""
|
||||
initializing_from_value = False
|
||||
if initializer is not None and not callable(initializer):
|
||||
initializing_from_value = True
|
||||
|
||||
if initializing_from_value:
|
||||
init_val = initializer
|
||||
variable_dtype = None
|
||||
else:
|
||||
# Instantiate initializer if provided initializer is a type object.
|
||||
if isinstance(
|
||||
initializer,
|
||||
(type(init_ops.Initializer), type(init_ops_v2.Initializer))):
|
||||
initializer = initializer()
|
||||
init_val = lambda: initializer(shape, dtype=dtype)
|
||||
variable_dtype = dtype.base_dtype
|
||||
if use_resource is None:
|
||||
use_resource = True
|
||||
|
||||
# TODO(apassos,rohanj) figure out how to remove collections from here so we
|
||||
# can remove the V1.
|
||||
variable_shape = tensor_shape.TensorShape(shape)
|
||||
return tf_variables.VariableV1(
|
||||
initial_value=init_val,
|
||||
name=name,
|
||||
trainable=trainable,
|
||||
caching_device=caching_device,
|
||||
dtype=variable_dtype,
|
||||
validate_shape=validate_shape,
|
||||
constraint=constraint,
|
||||
use_resource=use_resource,
|
||||
collections=collections,
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation,
|
||||
shape=variable_shape if variable_shape else None)
|
||||
|
||||
|
||||
def collect_previous_mask(input_tensors):
|
||||
"""Retrieves the output mask(s) of the previous node.
|
||||
|
||||
Arguments:
|
||||
input_tensors: An arbitrary structure of Tensors.
|
||||
|
||||
Returns:
|
||||
A mask tensor or list of mask tensors.
|
||||
"""
|
||||
|
||||
def _collect_previous_mask(x):
|
||||
return getattr(x, '_keras_mask', None)
|
||||
|
||||
return nest.map_structure(_collect_previous_mask, input_tensors)
|
||||
|
||||
|
||||
def have_all_keras_metadata(tensors):
|
||||
return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
|
||||
|
||||
|
||||
def generate_placeholders_from_shape(shape):
|
||||
return array_ops.placeholder(shape=shape, dtype=backend.floatx())
|
||||
|
||||
|
||||
def create_keras_history(tensors):
|
||||
"""Wraps TensorFlow Operations for compatibility with the Functional API.
|
||||
|
||||
This method checks to see if a Tensor in `tensors` is missing Keras metadata
|
||||
and has its origin in a Keras `Input` Layer. If so, this method will replace
|
||||
the raw TensorFlow Operations that created this tensor with
|
||||
`TensorFlowOpLayer` instances that create identical operations.
|
||||
|
||||
Any Tensors not originating from a Keras `Input` Layer will be treated as
|
||||
constants when constructing `TensorFlowOpLayer` instances.
|
||||
|
||||
Arguments:
|
||||
tensors: A structure of Tensors, some of which come from raw TensorFlow
|
||||
operations and need to have Keras metadata assigned to them.
|
||||
|
||||
Returns:
|
||||
created_layers: List. The `TensorFlowOpLayer` instances created to wrap
|
||||
the raw Tensorflow operations.
|
||||
"""
|
||||
_, created_layers = _create_keras_history_helper(tensors, set(), [])
|
||||
return created_layers
|
||||
|
||||
|
||||
def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
||||
"""Helper method for `create_keras_history`.
|
||||
|
||||
Arguments:
|
||||
tensors: A structure of Tensors for which to create Keras metadata.
|
||||
processed_ops: Set. TensorFlow operations that have already been wrapped in
|
||||
`TensorFlowOpLayer` instances.
|
||||
created_layers: List. The `TensorFlowOpLayer` instances created.
|
||||
|
||||
Returns:
|
||||
Tuple. First element is the updated set of TensorFlow Operations that
|
||||
have been wrapped in `TensorFlowOpLayer` instances. Second element is
|
||||
a list of the `TensorFlowOpLayer` instances created.
|
||||
"""
|
||||
# Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
|
||||
# Cannot be imported at top because of circular dependencies.
|
||||
# TODO(omalleyt): Resolve circular dependency.
|
||||
from tensorflow.python.frozen_keras.engine import legacy_base_layer as base_layer # pylint: disable=g-import-not-at-top
|
||||
tensor_list = nest.flatten(tensors)
|
||||
for tensor in tensor_list:
|
||||
if getattr(tensor, '_keras_history', None) is not None:
|
||||
continue
|
||||
op = tensor.op # The Op that created this Tensor.
|
||||
if op not in processed_ops:
|
||||
if op.type.startswith('Sparse'):
|
||||
lambda_example = """
|
||||
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
|
||||
output = tf.keras.layers.Lambda(weights_mult)(input)
|
||||
"""
|
||||
raise ValueError(
|
||||
'Sparse ops are not supported with functional models with built-in '
|
||||
'layer wrapping. Please wrap the sparse ops in a Lambda layer like'
|
||||
': \n{lambda_example}\n'.format(lambda_example=lambda_example))
|
||||
|
||||
# Recursively set `_keras_history`.
|
||||
op_inputs = list(op.inputs)
|
||||
constants = {}
|
||||
layer_inputs = []
|
||||
for i, op_input in enumerate(op_inputs):
|
||||
if uses_keras_history(op_input):
|
||||
layer_inputs.append(op_input)
|
||||
else:
|
||||
# Treat any value not originating from a `keras.Input` as
|
||||
# a constant. Variables cannot be supported.
|
||||
ds_with_session = (
|
||||
distribution_strategy_context.in_cross_replica_context() and
|
||||
not ops.executing_eagerly_outside_functions())
|
||||
using_xla = control_flow_util.GraphOrParentsInXlaContext(
|
||||
ops.get_default_graph())
|
||||
if ds_with_session or using_xla:
|
||||
# In Legacy Graph mode, evaluating here makes Session be
|
||||
# configured improperly. The downside of this is that saving
|
||||
# via `get_config` breaks, but SavedModel still works.
|
||||
constants[i] = op_input
|
||||
else:
|
||||
with ops.init_scope():
|
||||
constants[i] = backend.function([], op_input)([])
|
||||
layer_inputs = unnest_if_single_tensor(layer_inputs)
|
||||
processed_ops, created_layers = _create_keras_history_helper(
|
||||
layer_inputs, processed_ops, created_layers)
|
||||
name = op.name
|
||||
node_def = op.node_def.SerializeToString()
|
||||
op_layer = base_layer.TensorFlowOpLayer(
|
||||
node_def, constants=constants, name=name)
|
||||
created_layers.append(op_layer)
|
||||
op_layer._add_inbound_node( # pylint: disable=protected-access
|
||||
layer_inputs, op.outputs)
|
||||
processed_ops.update([op])
|
||||
return processed_ops, created_layers
|
||||
|
||||
|
||||
def unnest_if_single_tensor(input_tensors):
|
||||
# Preserve compatibility with older configs
|
||||
flat_input_tensors = nest.flatten(input_tensors)
|
||||
# If this is a single element but not a dict, unwrap. If this is a dict,
|
||||
# assume the first layer expects a dict (as is the case with a
|
||||
# DenseFeatures layer); pass through.
|
||||
if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
|
||||
input_tensors = flat_input_tensors[0]
|
||||
return input_tensors


def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  This will never return True inside a sublayer, because sublayers
  do not need to create Keras History. Otherwise, this returns True
  if one or more of `tensors` originates from a `keras.Input` and
  does not have `_keras_history` set.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to ignore the check of whether the caller is
      currently outside of a `call` context. This is `True` when creating
      KerasHistory inside `Node`, where we always know that Tensors
      are being used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  input_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  if all(
      getattr(tensor, '_keras_history', None) is not None
      for tensor in input_tensors):
    # KerasHistory already set.
    return False
  return uses_keras_history(tensors)


def is_in_keras_graph():
  """Returns if currently executing inside of a Keras graph."""
  return call_context().in_keras_graph


def is_in_eager_or_tf_function():
  """Returns if in eager mode or inside of a tf.function."""
  return context.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
  """Returns if inside of a tf.function."""
  # Check if running in V1 graph mode.
  if not ops.executing_eagerly_outside_functions():
    return False
  if not ops.inside_function():
    return False
  # Check if inside Keras FuncGraph.
  if is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  graph = ops.get_default_graph()
  if (getattr(graph, 'name', False) and
      graph.name.startswith('wrapped_function')):
    return False
  return True
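

def _is_in_tf_function_example():
  # Editor-added sketch, not part of the original file. Assumes TF 2.x
  # eager-by-default behavior and that `def_function` is importable as below:
  # inside a user-defined `tf.function` the helper above returns True, while
  # in plain eager mode it returns False.
  from tensorflow.python.eager import def_function

  @def_function.function
  def fn():
    return is_in_tf_function()

  assert not is_in_tf_function()  # eager mode: not inside a FuncGraph
  assert fn()                     # the traced body ran inside a FuncGraph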


def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  This is `True` if at least one Tensor has its origin in a `keras.Input`.
  Any Tensor that originates from a `keras.Input` will have a dependency
  Tensor with a `_keras_history` attribute attached. Tensors that have
  already been checked to not originate from a `keras.Input`
  are marked as `_keras_history_checked`.

  Arguments:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  checked_tensors = set()
  tensors_to_check = nest.flatten(tensors)

  while tensors_to_check:
    new_tensors_to_check = []
    for tensor in tensors_to_check:
      if id(tensor) in checked_tensors:
        continue

      checked_tensors.add(id(tensor))

      if getattr(tensor, '_keras_history_checked', None) is not None:
        continue
      if getattr(tensor, '_keras_history', None) is not None:
        return True

      try:
        new_tensors_to_check.extend(tensor.op.inputs)
      except AttributeError:
        # In case `tensor` is a Variable created in an Eager context.
        pass

    tensors_to_check = new_tensors_to_check

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False
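

def _keras_history_example():
  # Editor-added sketch, not part of the original file, of the convention the
  # two helpers above rely on, shown with the public tf.keras API (assumed
  # importable): symbolic tensors descending from a `keras.Input` carry a
  # `_keras_history` attribute.
  import tensorflow as tf
  x = tf.keras.Input(shape=(4,))
  y = tf.keras.layers.Dense(2)(x)
  assert hasattr(x, '_keras_history')  # origin: a `keras.Input`
  assert hasattr(y, '_keras_history')  # layer outputs inherit the history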


def mark_checked(tensors):
  """Marks that these Tensors should not be tracked.

  This prevents Layers from attempting to create TensorFlowOpLayers
  for these Tensors.

  Arguments:
    tensors: An arbitrary structure of Tensors.
  """

  def _mark_checked(tensor):
    tensor._keras_history_checked = True  # pylint: disable=protected-access

  nest.map_structure(_mark_checked, tensors)


def call_context():
  """Returns currently active `CallContext`."""
  if getattr(_call_context, 'call_context', None) is None:
    _call_context.call_context = CallContext()
  return _call_context.call_context


class CallContext(object):
  """Keeps track of properties currently inside a Layer/Model's `call`.

  Attributes:
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    frozen: Whether currently executing inside a `Layer` with `trainable` set
      to `False`.
    in_call: Whether currently inside the `call` of a Layer.
    training: Whether currently executing in training or inference mode.
    in_keras_graph: Whether executing inside the Keras Graph.
    saving: Whether currently saving to SavedModel.
  """

  def __init__(self):
    self.layer = None
    self.inputs = None
    self.frozen = False
    self.in_call = False
    self.training = None
    self._in_keras_graph = False
    self.saving = False

  @tf_contextlib.contextmanager
  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context."""
    prev_layer = self.layer
    prev_inputs = self.inputs
    prev_frozen = self.frozen
    prev_in_call = self.in_call
    prev_training = self.training
    prev_in_keras_graph = self._in_keras_graph
    prev_saving = self.saving

    self.layer = layer
    self.inputs = inputs
    self.frozen = self.frozen or not layer.trainable
    self.in_call = True
    self.training = training
    self._in_keras_graph = (
        self._in_keras_graph or
        (build_graph and
         getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
    self.saving = prev_saving if saving is None else saving

    try:
      yield
    finally:
      self.layer = prev_layer
      self.inputs = prev_inputs
      self.frozen = prev_frozen
      self.in_call = prev_in_call
      self.training = prev_training
      self._in_keras_graph = prev_in_keras_graph
      self.saving = prev_saving

  @property
  def in_keras_graph(self):
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    return (self._in_keras_graph or
            getattr(backend.get_graph(), 'name', None) == 'keras_graph')
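

def _call_context_example():
  # Editor-added sketch, not part of the original file, of how a layer's
  # `__call__` is expected to use the context manager above; `layer` and
  # `inputs` are stand-ins for a real Layer instance and its inputs.
  def wrapped_call(layer, inputs):
    ctx = call_context()
    # State is pushed for the duration of `call` and restored on exit,
    # even if `call` raises.
    with ctx.enter(layer, inputs, build_graph=False, training=False):
      assert ctx.in_call and ctx.training is False
      return layer.call(inputs)

  return wrapped_call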


def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`."""
  # `argspec.args` starts with ['self', 'inputs']
  full_args = dict(zip(argspec.args[2:], args))
  full_args.update(kwargs)
  return 'training' in full_args and full_args['training'] is not None
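

def _training_arg_example():
  # Editor-added sketch, not part of the original file; the namedtuple is a
  # stand-in for the inspected argspec of `def call(self, inputs,
  # training=None)`.
  import collections
  ArgSpec = collections.namedtuple('ArgSpec', ['args'])
  spec = ArgSpec(args=['self', 'inputs', 'training'])
  assert training_arg_passed_to_call(spec, args=(), kwargs={'training': True})
  assert training_arg_passed_to_call(spec, args=(False,), kwargs={})
  assert not training_arg_passed_to_call(spec, args=(), kwargs={})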


def autocast_context_manager(dtype):
  """Returns a context manager to autocast AutoCastVariables.

  Under this context manager, AutoCastVariables will be cast to `dtype` if
  `dtype` is floating-point. Otherwise, AutoCastVariables will not be cast.

  Args:
    dtype: The dtype to cast AutoCastVariables to, or None.

  Returns:
    A context manager to automatically cast AutoCastVariables.
  """
  if dtype and not dtypes.as_dtype(dtype).is_floating:
    dtype = None
  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access


def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model."""
  return (layer.__module__.find('keras.engine') == -1 and
          layer.__module__.find('keras.layers') == -1)


def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel."""
  return layer.__module__.find('keras.saving.saved_model') != -1


def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to an `add_*` method match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Arguments:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  if (force_raise or
      (ops.executing_eagerly_outside_functions() and
       hasattr(tensor, 'graph') and
       isinstance(tensor.graph,
                  (control_flow_v2_func_graphs.CondBranchFuncGraph,
                   control_flow_v2_func_graphs.WhileCondFuncGraph,
                   control_flow_v2_func_graphs.WhileBodyFuncGraph)))):
    if method == 'activity_regularizer':
      bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
      correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
      raise RuntimeError(
          'You are using a layer with `activity_regularizer` in a control flow '
          'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
          'Please move your call to the layer with `activity_regularizer` out '
          'of the control flow branch, e.g.:\n{correct_example}\n'
          'You can also resolve this by marking your outer model/layer dynamic'
          ' (eager-only) by passing `dynamic=True` to the layer constructor. '
          'Any kind of control flow is supported with dynamic layers. '
          'Note that using `dynamic=True` requires you to implement static '
          'shape inference in the `compute_output_shape(input_shape)` '
          'method.'.format(
              bad_example=bad_example, correct_example=correct_example))

    if method == 'add_metric':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    elif method == 'add_loss':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
    else:
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
    raise RuntimeError(
        'You are using the method `{method}` in a control flow branch '
        'in your layer, e.g.:\n{bad_example}\n'
        'This is not currently supported. '
        'Please move your call to {method} out of the control flow branch, '
        'e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your layer '
        'as dynamic (eager-only) by passing '
        '`dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you '
        'to implement static shape inference '
        'in the `compute_output_shape(input_shape)` method.'.format(
            method=method,
            bad_example=bad_example,
            correct_example=correct_example))


def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps."""

  def _mark_as_return(tensor):
    """Marks `tensor` as the return value for automatic control deps."""
    if not tensor_util.is_tensor(tensor):
      return tensor

    # pylint: disable=protected-access
    return_tensor = acd.mark_as_return(tensor)
    if getattr(tensor, '_keras_mask', None) is not None:
      return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
    else:
      return_tensor._keras_mask = None

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    if getattr(tensor, '_tfp_distribution', None) is not None:
      return_tensor._tfp_distribution = tensor._tfp_distribution

    return return_tensor
    # pylint: enable=protected-access

  return nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


# These two functions are not exported because we plan on removing them in the
# future.
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2.

  When enabled, the dtype of Keras layers defaults to floatx (which is
  typically float32) instead of None. In addition, layers will automatically
  cast floating-point inputs to the layer's dtype.

  For example, once enabled, the following block will run a Conv2D layer
  in float32:

  ```python
  x = tf.ones((4, 4, 4, 4), dtype='float64')
  layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  print(layer.dtype)  # Float32 when enabled. None when disabled.
  # When enabled, will cast inputs to the layer's dtype, which is float32.
  # When disabled, will do no casting, so the layer is done in float64.
  y = layer(x)
  ```

  A layer author can opt a layer out of the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting
  to floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.experimental.Policy` is set, the
  layer's dtype will default to the global policy instead of floatx. Layers
  will automatically cast inputs to the policy's compute_dtype.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True


def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `enable_v2_dtype_behavior`.

  This function will be removed in the future.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled."""
  if V2_DTYPE_BEHAVIOR is None:
    return tf2.enabled()
  return V2_DTYPE_BEHAVIOR
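

def _dtype_behavior_example():
  # Editor-added sketch, not part of the original file: the module-level flag
  # overrides the TF2 default until toggled back.
  disable_v2_dtype_behavior()
  assert not v2_dtype_behavior_enabled()
  enable_v2_dtype_behavior()
  assert v2_dtype_behavior_enabled()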


class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they
  can be saved and restored with the correct data and without adding
  additional ops on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for
      saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    if len(saveables) != 1:
      raise ValueError('Only Trackables with one Saveable are supported.')
    saveable = list(saveables)[0]

    if ops.executing_eagerly_outside_functions():
      # If we're in eager mode, we need to defer calling the Trackable's
      # saveable() callable until data export time.
      # However, it is safe to call the saveable as many times as we want, so
      # we will call it now to figure out how many tensors this Trackable will
      # produce.
      self._saveable = saveable
      self._num_tensors = len(self._saveable().specs)
      self._setter = lambda weights: self._saveable().restore(weights, None)
      self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
    else:
      # If we're in Graph mode, we need to evaluate the Saveable only once and
      # cache the resulting restore graph. Failing to do this will result in
      # new assignment ops being added to the graph each time set_weights() is
      # called.
      self._placeholder_tensors = []
      self._saveable = saveable()
      self._num_tensors = len(self._saveable.specs)
      for spec in self._saveable.specs:
        tensor = spec.tensor
        self._placeholder_tensors.append(
            array_ops.placeholder(tensor.dtype, tensor.shape))
      self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
      self._setter = self._set_weights_v1
      self._getter = lambda: [spec.tensor for spec in self._saveable.specs]

  @property
  def num_tensors(self):
    return self._num_tensors

  def set_weights(self, weights):
    if len(weights) != self._num_tensors:
      raise ValueError(
          ('Weight handler for trackable %s received the wrong number of ' +
           'weights: expected %s, got %s.') %
          (self._trackable, self._num_tensors, len(weights)))
    self._setter(weights)

  def get_tensors(self):
    return self._getter()

  def _set_weights_v1(self, weights):
    feed_dict = {}
    for idx, tensor in enumerate(weights):
      feed_dict[self._placeholder_tensors[idx]] = tensor
    backend.get_session().run(self._assign_op, feed_dict)


# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes
# issues in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    'This layer\'s losses have been added to the parent layer.')
@ -1,71 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.frozen_keras import backend
from tensorflow.python.frozen_keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class TrackableWeightHandlerTest(test.TestCase, parameterized.TestCase):

  def get_table_handler(self):
    # Note: There is some repetition in these tests' setup. However, TensorFlow
    # does not play nicely with a separate setUp() call (causing errors related
    # to graph building), so we have to use a called setup instead of a setUp()
    # call.
    table = lookup_ops.MutableHashTable(
        key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
    return base_layer_utils.TrackableWeightHandler(table)

  def test_get_num_tensors(self):
    table_handler = self.get_table_handler()
    self.assertEqual(2, table_handler.num_tensors)

  def test_get_and_set_weights(self):
    table_handler = self.get_table_handler()

    table_data = {b"a": 1, b"b": 2, b"c": 3}
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    weights = backend.batch_get_value(table_handler.get_tensors())
    weight_data = {key: value for key, value in zip(weights[0], weights[1])}
    self.assertDictEqual(table_data, weight_data)

  def test_get_and_set_weights_does_not_add_ops(self):
    table_handler = self.get_table_handler()
    table_data = {b"a": 1, b"b": 2, b"c": 3}
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    _ = backend.batch_get_value(table_handler.get_tensors())
    backend.get_session().graph.finalize()
    table_handler.set_weights(
        [list(table_data.keys()),
         list(table_data.values())])
    _ = backend.batch_get_value(table_handler.get_tensors())


if __name__ == "__main__":
  test.main()
@ -1,233 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the InputSpec class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.frozen_keras import backend
from tensorflow.python.util import nest


class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension;
  a None shape is compatible with any shape.

  Arguments:
    dtype: Expected DataType of the input.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes).
    ndim: Integer, expected rank of the input.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None):
    self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
    if shape is not None:
      self.ndim = len(shape)
      self.shape = shape
    else:
      self.ndim = ndim
      self.shape = None
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
            ('shape=' + str(self.shape)) if self.shape else '',
            ('ndim=' + str(self.ndim)) if self.ndim else '',
            ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
            ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
            ('axes=' + str(self.axes)) if self.axes else '']
    return 'InputSpec(%s)' % ', '.join(x for x in spec if x)

  def get_config(self):
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes}

  @classmethod
  def from_config(cls, config):
    return cls(**config)


def to_tensor_shape(spec):
  """Returns a tf.TensorShape object that matches the shape specifications.

  If the InputSpec's shape or ndim is defined, this method will return a fully
  or partially-known shape. Otherwise, the returned TensorShape is fully
  unknown (rank None).

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.ndim is None and spec.shape is None:
    return tensor_shape.TensorShape(None)
  elif spec.shape is not None:
    return tensor_shape.TensorShape(spec.shape)
  else:
    shape = [None] * spec.ndim
    for a in spec.axes:
      shape[a] = spec.axes[a]  # Assume that axes is defined
    return tensor_shape.TensorShape(shape)
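

def _to_tensor_shape_example():
  # Editor-added sketch, not part of the original file; mirrors the cases
  # exercised by the InputSpec tests further below.
  assert to_tensor_shape(
      InputSpec(shape=[1, None, 2])).as_list() == [1, None, 2]
  assert to_tensor_shape(
      InputSpec(ndim=3, axes={-1: 2})).as_list() == [None, None, 2]
  assert to_tensor_shape(InputSpec(max_ndim=5)).rank is None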


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  Arguments:
    input_spec: An InputSpec instance, list of InputSpec instances, a nested
      structure of InputSpec instances, or None.
    inputs: Input tensor, list of input tensors, or a nested structure of
      input tensors.
    layer_name: String, name of the layer (for error message formatting).

  Raises:
    ValueError: in case of mismatch between
      the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  inputs = nest.flatten(inputs)
  input_spec = nest.flatten(input_spec)
  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' inputs, '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    if (spec.ndim is not None or
        spec.min_ndim is not None or
        spec.max_ndim is not None):
      if x.shape.ndims is None:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'its rank is undefined, but the layer requires a '
                         'defined rank.')

    # Check ndim.
    if spec.ndim is not None:
      ndim = x.shape.ndims
      if ndim != spec.ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                         str(ndim) + '. Full shape received: ' +
                         str(x.shape.as_list()))
    if spec.max_ndim is not None:
      ndim = x.shape.ndims
      if ndim is not None and ndim > spec.max_ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected max_ndim=' + str(spec.max_ndim) +
                         ', found ndim=' + str(ndim))
    if spec.min_ndim is not None:
      ndim = x.shape.ndims
      if ndim is not None and ndim < spec.min_ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected min_ndim=' + str(spec.min_ndim) +
                         ', found ndim=' + str(ndim) +
                         '. Full shape received: ' +
                         str(x.shape.as_list()))
    # Check dtype.
    if spec.dtype is not None:
      if x.dtype != spec.dtype:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected dtype=' + str(spec.dtype) +
                         ', found dtype=' + str(x.dtype))
    # Check specific shape axes.
    if spec.axes:
      shape = x.shape.as_list()
      if shape is not None:
        for axis, value in spec.axes.items():
          if hasattr(value, 'value'):
            value = value.value
          if value is not None and shape[int(axis)] not in {value, None}:
            raise ValueError(
                'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
                ' incompatible with the layer: expected axis ' + str(axis) +
                ' of input shape to have value ' + str(value) +
                ' but received input with shape ' + str(shape))
    # Check shape.
    if spec.shape is not None:
      shape = x.shape.as_list()
      if shape is not None:
        for spec_dim, dim in zip(spec.shape, shape):
          if spec_dim is not None and dim is not None:
            if spec_dim != dim:
              raise ValueError('Input ' + str(input_index) +
                               ' is incompatible with layer ' + layer_name +
                               ': expected shape=' + str(spec.shape) +
                               ', found shape=' + str(shape))
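

def _assert_input_compatibility_example():
  # Editor-added sketch, not part of the original file; assumes eager TF 2.x
  # and uses the public `tf` namespace for the input tensors.
  import tensorflow as tf
  spec = InputSpec(ndim=2, axes={-1: 3})
  assert_input_compatibility(spec, tf.ones([4, 3]), 'my_layer')  # passes
  try:
    assert_input_compatibility(spec, tf.ones([4]), 'my_layer')
  except ValueError:
    pass  # rank mismatch: expected ndim=2, found ndim=1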


def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec."""
  default_dtype = default_dtype or backend.floatx()
  if isinstance(input_spec, InputSpec):
    dtype = input_spec.dtype or default_dtype
    return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
  return tensor_spec.TensorSpec(None, default_dtype)
@ -1,66 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""InputSpec tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.frozen_keras.engine import input_spec
from tensorflow.python.platform import test


class InputSpecTest(test.TestCase):

  def test_axes_initialization(self):
    input_spec.InputSpec(shape=[1, None, 2, 3], axes={3: 5, '2': 2})
    with self.assertRaisesRegexp(ValueError, 'Axis 4 is greater than'):
      input_spec.InputSpec(shape=[1, None, 2, 3], axes={4: 5})
    with self.assertRaisesRegexp(TypeError, 'keys in axes must be integers'):
      input_spec.InputSpec(shape=[1, None, 2, 3], axes={'string': 5})


class InputSpecToTensorShapeTest(test.TestCase):

  def test_defined_shape(self):
    spec = input_spec.InputSpec(shape=[1, None, 2, 3])
    self.assertAllEqual(
        [1, None, 2, 3], input_spec.to_tensor_shape(spec).as_list())

  def test_defined_ndims(self):
    spec = input_spec.InputSpec(ndim=5)
    self.assertAllEqual(
        [None] * 5, input_spec.to_tensor_shape(spec).as_list())

    spec = input_spec.InputSpec(ndim=0)
    self.assertAllEqual(
        [], input_spec.to_tensor_shape(spec).as_list())

    spec = input_spec.InputSpec(ndim=3, axes={1: 3, -1: 2})
    self.assertAllEqual(
        [None, 3, 2], input_spec.to_tensor_shape(spec).as_list())

  def test_undefined_shapes(self):
    spec = input_spec.InputSpec(max_ndim=5)
    with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
      input_spec.to_tensor_shape(spec).as_list()

    spec = input_spec.InputSpec(min_ndim=5, max_ndim=5)
    with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
      input_spec.to_tensor_shape(spec).as_list()


if __name__ == '__main__':
  test.main()
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,190 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the `Node` class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.frozen_keras import backend
from tensorflow.python.frozen_keras.engine import base_layer_utils
from tensorflow.python.util import nest


class Node(object):
  """A `Node` describes the connectivity between two layers.

  Each time a layer is connected to some new input,
  a node is added to `layer._inbound_nodes`.
  Each time the output of a layer is used by another layer,
  a node is added to `layer._outbound_nodes`.

  Arguments:
    outbound_layer: the layer that takes
      `input_tensors` and turns them into `output_tensors`
      (the node gets created when the `call`
      method of the layer is called).
    inbound_layers: a list of layers, the same length as `input_tensors`,
      the layers from where `input_tensors` originate.
    node_indices: a list of integers, the same length as `inbound_layers`.
      `node_indices[i]` is the origin node of `input_tensors[i]`
      (necessary since each inbound layer might have several nodes,
      e.g. if the layer is being shared with a different data stream).
    tensor_indices: a list of integers,
      the same length as `inbound_layers`.
      `tensor_indices[i]` is the index of `input_tensors[i]` within the
      output of the inbound layer
      (necessary since each inbound layer might
      have multiple tensor outputs, with each one being
      independently manipulable).
    input_tensors: list of input tensors.
    output_tensors: list of output tensors.
    arguments: dictionary of keyword arguments that were passed to the
      `call` method of the layer at the call that created the node.

  `node_indices` and `tensor_indices` are basically fine-grained coordinates
  describing the origin of the `input_tensors`.

  A node from layer A to layer B is added to:
    - A._outbound_nodes
    - B._inbound_nodes
  """

  def __init__(self,
               outbound_layer,
               inbound_layers,
               node_indices,
               tensor_indices,
               input_tensors,
               output_tensors,
               arguments=None):
    # Layer instance (NOT a sequence)
    if isinstance(outbound_layer, (list, tuple, dict)):
      raise ValueError('`outbound_layer` should be a layer instance, '
                       'not a list, tuple, or dict.')

    # this is the layer that takes a nested structure of input tensors
    # and turns them into a nested structure of output tensors.
    # the current node will be added to
    # the inbound_nodes of outbound_layer.
    self.outbound_layer = outbound_layer

    # The following 3 properties describe where
    # the input tensors come from: which layers,
    # and for each layer, which node and which
    # tensor output of each node.

    # Nested structure of layer instances.
    self.inbound_layers = inbound_layers
    # Nested structure of integers, 1:1 mapping with inbound_layers.
    self.node_indices = node_indices
    # Nested structure of integers, 1:1 mapping with inbound_layers.
    self.tensor_indices = tensor_indices

    # Following 2 properties:
    # tensor inputs and outputs of outbound_layer.

    # Nested structure of tensors. 1:1 mapping with inbound_layers.
    self.input_tensors = input_tensors
    # Nested structure of tensors, created by outbound_layer.call().
    self.output_tensors = output_tensors

    # Following 2 properties: input and output shapes.

    # Nested structure of shape tuples, shapes of input_tensors.
    self.input_shapes = nest.map_structure(backend.int_shape, input_tensors)
    # Nested structure of shape tuples, shapes of output_tensors.
    self.output_shapes = nest.map_structure(backend.int_shape, output_tensors)

    # Optional keyword arguments to layer's `call`.
    self.arguments = arguments

    # Create Keras History for any Keras Tensors in `arguments`.
    tensor_arguments = [
        t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
    ]
    for tensor_argument in tensor_arguments:
      if base_layer_utils.needs_keras_history(
          tensor_argument, ignore_call_context=True):
        base_layer_utils.create_keras_history(tensor_argument)

    # Add nodes to all layers involved.
    for layer in nest.flatten(inbound_layers):
      if layer is not None:
        # For compatibility with external Keras, we use the deprecated
        # accessor here.
        layer.outbound_nodes.append(self)
    # For compatibility with external Keras, we use the deprecated
    # accessor here.
    outbound_layer.inbound_nodes.append(self)

  def iterate_inbound(self, include_arguments=False):
    """Returns a list of tuples representing the inbound data.

    Arguments:
      include_arguments: Whether to also iterate over any Keras Tensors
        passed as args, kwargs.

    Returns:
      List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
    """
    inputs_inbound = list(
        zip(
            nest.flatten(self.inbound_layers),
            nest.flatten(self.node_indices),
            nest.flatten(self.tensor_indices),
            nest.flatten(self.input_tensors)))

    if include_arguments:
      keras_tensor_arguments = [
          kt for kt in nest.flatten(self.arguments)
          if hasattr(kt, '_keras_history')
      ]

      def _get_inbound(keras_tensor):
        kh = keras_tensor._keras_history
        return kh.layer, kh.node_index, kh.tensor_index, keras_tensor

      arguments_inbound = nest.map_structure(_get_inbound,
                                             keras_tensor_arguments)

      return inputs_inbound + arguments_inbound
    else:
      return inputs_inbound

  def _get_all_node_dependencies(self):
    """Returns all of the nodes this node immediately depends on."""
    node_deps = []
    for layer, node_index, _, _ in self.iterate_inbound():
      node_deps.append(layer._inbound_nodes[node_index])

    for arg in nest.flatten(self.arguments):
      if isinstance(arg, ops.Tensor) and hasattr(arg, '_keras_history'):
        kh = arg._keras_history
        node_deps.append(kh.layer._inbound_nodes[kh.node_index])

    return node_deps

  def get_config(self):
    inbound_names = nest.map_structure(
        lambda layer: layer.name if layer else None, self.inbound_layers)
    return {
        'outbound_layer': self.outbound_layer.name,
        'inbound_layers': inbound_names,
        'node_indices': self.node_indices,
        'tensor_indices': self.tensor_indices
    }
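

def _node_connectivity_example():
  # Editor-added sketch, not part of the original file, shown with the public
  # tf.keras API (assumed importable): nodes are created implicitly each time
  # a layer is called on new symbolic inputs, so a shared layer accumulates
  # one inbound node per call.
  import tensorflow as tf
  dense = tf.keras.layers.Dense(2)
  dense(tf.keras.Input(shape=(4,)))
  dense(tf.keras.Input(shape=(4,)))
  assert len(dense.inbound_nodes) == 2  # one Node per symbolic call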

@ -1,198 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import init_ops_v2

# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
from tensorflow.python.ops.init_ops import GlorotNormal
from tensorflow.python.ops.init_ops import GlorotUniform
from tensorflow.python.ops.init_ops import he_normal  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import he_uniform  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_normal  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_uniform  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormal
from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniform
from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormal
from tensorflow.python.ops.init_ops import VarianceScaling  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
# pylint: disable=unused-import, disable=line-too-long
from tensorflow.python.ops.init_ops_v2 import Constant as ConstantV2
from tensorflow.python.ops.init_ops_v2 import GlorotNormal as GlorotNormalV2
from tensorflow.python.ops.init_ops_v2 import GlorotUniform as GlorotUniformV2
from tensorflow.python.ops.init_ops_v2 import he_normal as he_normalV2
from tensorflow.python.ops.init_ops_v2 import he_uniform as he_uniformV2
from tensorflow.python.ops.init_ops_v2 import Identity as IdentityV2
from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
from tensorflow.python.ops.init_ops_v2 import lecun_normal as lecun_normalV2
from tensorflow.python.ops.init_ops_v2 import lecun_uniform as lecun_uniformV2
from tensorflow.python.ops.init_ops_v2 import Ones as OnesV2
from tensorflow.python.ops.init_ops_v2 import Orthogonal as OrthogonalV2
from tensorflow.python.ops.init_ops_v2 import RandomNormal as RandomNormalV2
from tensorflow.python.ops.init_ops_v2 import RandomUniform as RandomUniformV2
from tensorflow.python.ops.init_ops_v2 import TruncatedNormal as TruncatedNormalV2
from tensorflow.python.ops.init_ops_v2 import VarianceScaling as VarianceScalingV2
from tensorflow.python.ops.init_ops_v2 import Zeros as ZerosV2
# pylint: enable=unused-import, enable=line-too-long


class TruncatedNormal(TFTruncatedNormal):
  """Initializer that generates a truncated normal distribution.

  These values are similar to values from a `random_normal_initializer`
  except that values more than two standard deviations from the mean
  are discarded and re-drawn. This is the recommended initializer for
  neural network weights and filters.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate. Defaults to 0.
    stddev: a python scalar or a scalar tensor. Standard deviation of the
      random values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    A TruncatedNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super(TruncatedNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


class RandomUniform(TFRandomUniform):
  """Initializer that generates tensors with a uniform distribution.

  Args:
    minval: A python scalar or a scalar tensor. Lower bound of the range of
      random values to generate. Defaults to -0.05.
    maxval: A python scalar or a scalar tensor. Upper bound of the range of
      random values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type.

  Returns:
    A RandomUniform instance.
  """

  def __init__(self, minval=-0.05, maxval=0.05, seed=None,
               dtype=dtypes.float32):
    super(RandomUniform, self).__init__(
        minval=minval, maxval=maxval, seed=seed, dtype=dtype)


class RandomNormal(TFRandomNormal):
  """Initializer that generates tensors with a normal distribution.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate. Defaults to 0.
    stddev: a python scalar or a scalar tensor. Standard deviation of the
      random values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    RandomNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super(RandomNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


# Compatibility aliases

# pylint: disable=invalid-name
zero = zeros = Zeros
one = ones = Ones
constant = Constant
uniform = random_uniform = RandomUniform
normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform


# Utility functions


def serialize(initializer):
  return serialize_keras_object(initializer)


def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config."""
  if tf2.enabled():
    # Class names are the same for V1 and V2 but the V2 classes
    # are aliased in this file so we need to grab them directly
    # from `init_ops_v2`.
    module_objects = {
        obj_name: getattr(init_ops_v2, obj_name)
        for obj_name in dir(init_ops_v2)
    }
  else:
    module_objects = globals()
  return deserialize_keras_object(
      config,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='initializer')


def get(identifier):
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    identifier = str(identifier)
    # We have to special-case functions that return classes.
    # TODO(omalleyt): Turn these into classes or class aliases.
    special_cases = ['he_normal', 'he_uniform', 'lecun_normal', 'lecun_uniform']
    if identifier in special_cases:
      # Treat like a class.
      return deserialize({'class_name': identifier, 'config': {}})
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))
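

def _initializer_get_example():
  # Editor-added sketch, not part of the original file; the same dispatch is
  # exposed publicly as `tf.keras.initializers.get`.
  assert get(None) is None
  init_from_config = get({'class_name': 'Zeros', 'config': {}})
  init_from_string = get('he_normal')       # special-cased function-style name
  init_passthrough = get(init_from_config)  # callables are returned as-is
  assert init_passthrough is init_from_config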


# pylint: enable=invalid-name
@ -1,266 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in regularizers."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops


class Regularizer(object):
  """Regularizer base class.

  Regularizers allow you to apply penalties on layer parameters or layer
  activity during optimization. These penalties are summed into the loss
  function that the network optimizes.

  Regularization penalties are applied on a per-layer basis. The exact API
  will depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D`
  and `Conv3D`) have a unified API.

  These layers expose 3 keyword arguments:

  - `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel
  - `bias_regularizer`: Regularizer to apply a penalty on the layer's bias
  - `activity_regularizer`: Regularizer to apply a penalty on the layer's
    output

  All layers (including custom layers) expose `activity_regularizer` as a
  settable property, whether or not it is in the constructor arguments.

  The value returned by the `activity_regularizer` is divided by the input
  batch size so that the relative weighting between the weight regularizers
  and the activity regularizers does not change with the batch size.

  You can access a layer's regularization penalties by calling `layer.losses`
  after calling the layer on inputs.

  ## Example

  >>> layer = tf.keras.layers.Dense(
  ...     5, input_dim=5,
  ...     kernel_initializer='ones',
  ...     kernel_regularizer=tf.keras.regularizers.l1(0.01),
  ...     activity_regularizer=tf.keras.regularizers.l2(0.01))
  >>> tensor = tf.ones(shape=(5, 5)) * 2.0
  >>> out = layer(tensor)

  >>> # The kernel regularization term is 0.25
  >>> # The activity regularization term (after dividing by the batch size) is 5
  >>> tf.math.reduce_sum(layer.losses)
  <tf.Tensor: shape=(), dtype=float32, numpy=5.25>

  ## Available penalties

  ```python
  tf.keras.regularizers.l1(0.3)  # L1 Regularization Penalty
  tf.keras.regularizers.l2(0.1)  # L2 Regularization Penalty
  tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)  # L1 + L2 penalties
  ```

  ## Directly calling a regularizer

  Compute a regularization loss on a tensor by directly calling a regularizer
  as if it is a one-argument function.

  E.g.
  >>> regularizer = tf.keras.regularizers.l2(2.)
  >>> tensor = tf.ones(shape=(5, 5))
  >>> regularizer(tensor)
  <tf.Tensor: shape=(), dtype=float32, numpy=50.0>

  ### A note on serialization and deserialization:

  Registering the regularizers as serializable is optional if you are just
  training and executing models, exporting to and from SavedModels, or saving
  and loading weight checkpoints.

  Registration is required for Keras `model_to_estimator`, saving and
  loading models to HDF5 formats, Keras model cloning, some visualization
  utilities, and exporting models to and from JSON. If using this
  functionality, you must make sure any python process running your model has
  also defined and registered your custom regularizer.

  `tf.keras.utils.register_keras_serializable` is only available in TF 2.1
  and beyond. In earlier versions of TensorFlow you must pass your custom
  regularizer to the `custom_objects` argument of methods that expect custom
  regularizers to be registered as serializable.
  """

  def __call__(self, x):
    """Compute a regularization penalty from an input tensor."""
    return 0.

  @classmethod
  def from_config(cls, config):
    """Creates a regularizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same regularizer from the config
    dictionary.

    This method is used by Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Arguments:
      config: A Python dictionary, typically the output of get_config.

    Returns:
      A regularizer instance.
    """
    return cls(**config)

  def get_config(self):
    """Returns the config of the regularizer.

    A regularizer config is a Python dictionary (serializable)
    containing all configuration parameters of the regularizer.
    The same regularizer can be reinstantiated later
    (without any saved state) from this configuration.

    This method is optional if you are just training and executing models,
    exporting to and from SavedModels, or using weight checkpoints.

    This method is required for Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Returns:
      Python dictionary.
    """
    raise NotImplementedError(str(self) + ' does not implement get_config()')


class L1L2(Regularizer):
  r"""A regularizer that applies both L1 and L2 regularization penalties.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$

  The L2 regularization penalty is computed as
  $$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$

  Attributes:
    l1: Float; L1 regularization factor.
    l2: Float; L2 regularization factor.
  """

  def __init__(self, l1=0., l2=0.):  # pylint: disable=redefined-outer-name
    self.l1 = K.cast_to_floatx(l1)
    self.l2 = K.cast_to_floatx(l2)

  def __call__(self, x):
    if not self.l1 and not self.l2:
      return K.constant(0.)
    regularization = 0.
    if self.l1:
      regularization += self.l1 * math_ops.reduce_sum(math_ops.abs(x))
    if self.l2:
      regularization += self.l2 * math_ops.reduce_sum(math_ops.square(x))
    return regularization

  def get_config(self):
    return {'l1': float(self.l1), 'l2': float(self.l2)}


# Aliases.


def l1(l=0.01):
  r"""Create a regularizer that applies an L1 regularization penalty.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$

  Arguments:
    l: Float; L1 regularization factor.

  Returns:
    An L1 Regularizer with the given regularization factor.
  """
  return L1L2(l1=l)


def l2(l=0.01):
  r"""Create a regularizer that applies an L2 regularization penalty.

  The L2 regularization penalty is computed as:
  $$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$

  Arguments:
    l: Float; L2 regularization factor.

  Returns:
    An L2 Regularizer with the given regularization factor.
  """
  return L1L2(l2=l)
|
||||
|
||||
|
||||
def l1_l2(l1=0.01, l2=0.01): # pylint: disable=redefined-outer-name
|
||||
r"""Create a regularizer that applies both L1 and L2 penalties.
|
||||
|
||||
The L1 regularization penalty is computed as:
|
||||
$$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$
|
||||
|
||||
The L2 regularization penalty is computed as:
|
||||
$$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$
|
||||
|
||||
Arguments:
|
||||
l1: Float; L1 regularization factor.
|
||||
l2: Float; L2 regularization factor.
|
||||
|
||||
Returns:
|
||||
An L1L2 Regularizer with the given regularization factors.
|
||||
"""
|
||||
return L1L2(l1=l1, l2=l2)
|
||||
|
||||
|
||||
def serialize(regularizer):
|
||||
return serialize_keras_object(regularizer)
|
||||
|
||||
|
||||
def deserialize(config, custom_objects=None):
|
||||
return deserialize_keras_object(
|
||||
config,
|
||||
module_objects=globals(),
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='regularizer')
|
||||
|
||||
|
||||
def get(identifier):
|
||||
if identifier is None:
|
||||
return None
|
||||
if isinstance(identifier, dict):
|
||||
return deserialize(identifier)
|
||||
elif isinstance(identifier, six.string_types):
|
||||
identifier = str(identifier)
|
||||
# We have to special-case functions that return classes.
|
||||
# TODO(omalleyt): Turn these into classes or class aliases.
|
||||
special_cases = ['l1', 'l2', 'l1_l2']
|
||||
if identifier in special_cases:
|
||||
# Treat like a class.
|
||||
return deserialize({'class_name': identifier, 'config': {}})
|
||||
return deserialize(str(identifier))
|
||||
elif callable(identifier):
|
||||
return identifier
|
||||
else:
|
||||
raise ValueError('Could not interpret regularizer identifier:', identifier)
|
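
# Editor's note: an illustrative sketch, not part of the original file. It
# shows the registration workflow described in the serialization note above;
# the class `L3` is hypothetical, and TF 2.1+ is assumed for
# `tf.keras.utils.register_keras_serializable`.
#
# import tensorflow as tf
#
# @tf.keras.utils.register_keras_serializable(package='Custom', name='l3')
# class L3(tf.keras.regularizers.Regularizer):
#   """A cubic penalty, registered so saved models can find it by name."""
#
#   def __init__(self, l3=0.01):
#     self.l3 = l3
#
#   def __call__(self, x):
#     return self.l3 * tf.math.reduce_sum(tf.math.abs(x) ** 3)
#
#   def get_config(self):
#     return {'l3': self.l3}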
@ -1,106 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
    default_visibility = ["//tensorflow:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)

py_library(
    name = "tf_utils",
    srcs = ["tf_utils.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:smart_cond",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/experimental/ops:cardinality",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/frozen_keras:backend",
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)

py_library(
    name = "conv_utils",
    srcs = [
        "conv_utils.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python/frozen_keras:backend",
    ],
)

py_library(
    name = "generic_utils",
    srcs = [
        "generic_utils.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "layer_utils",
    srcs = [
        "layer_utils.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":conv_utils",
        "//tensorflow/python:util",
        "//tensorflow/python/frozen_keras:backend",
        "//third_party/py/numpy",
    ],
)

tf_py_test(
    name = "generic_utils_test",
    size = "small",
    srcs = ["generic_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":generic_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/frozen_keras:regularizers",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "tf_utils_test",
    size = "small",
    srcs = ["tf_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":tf_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
    ],
)

tf_py_test(
    name = "conv_utils_test",
    size = "small",
    srcs = ["conv_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":conv_utils",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
@ -1,482 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by convolution layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
from six.moves import range  # pylint: disable=redefined-builtin

from tensorflow.python.frozen_keras import backend


def convert_data_format(data_format, ndim):
  if data_format == 'channels_last':
    if ndim == 3:
      return 'NWC'
    elif ndim == 4:
      return 'NHWC'
    elif ndim == 5:
      return 'NDHWC'
    else:
      raise ValueError('Input rank not supported:', ndim)
  elif data_format == 'channels_first':
    if ndim == 3:
      return 'NCW'
    elif ndim == 4:
      return 'NCHW'
    elif ndim == 5:
      return 'NCDHW'
    else:
      raise ValueError('Input rank not supported:', ndim)
  else:
    raise ValueError('Invalid data_format:', data_format)


def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into an integer tuple.

  Arguments:
    value: The value to validate and convert. Could be an int, or any iterable
      of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n integers.

  Raises:
    ValueError: If something other than an int/long or an iterable thereof was
      passed.
  """
  if isinstance(value, int):
    return (value,) * n
  else:
    try:
      value_tuple = tuple(value)
    except TypeError:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    if len(value_tuple) != n:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    for single_value in value_tuple:
      try:
        int(single_value)
      except (ValueError, TypeError):
        raise ValueError('The `' + name + '` argument must be a tuple of ' +
                         str(n) + ' integers. Received: ' + str(value) + ' '
                         'including element ' + str(single_value) + ' of type'
                         ' ' + str(type(single_value)))
    return value_tuple


def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Determines output length of a convolution given input length.

  Arguments:
    input_length: integer.
    filter_size: integer.
    padding: one of "same", "valid", "full", "causal".
    stride: integer.
    dilation: dilation rate, integer.

  Returns:
    The output length (integer).
  """
  if input_length is None:
    return None
  assert padding in {'same', 'valid', 'full', 'causal'}
  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
  if padding in ['same', 'causal']:
    output_length = input_length
  elif padding == 'valid':
    output_length = input_length - dilated_filter_size + 1
  elif padding == 'full':
    output_length = input_length + dilated_filter_size - 1
  return (output_length + stride - 1) // stride
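
# Editor's note: illustrative values, not part of the original file; they
# match the expectations in `conv_utils_test.py` later in this diff:
#
# >>> conv_output_length(4, 2, 'same', 1)
# 4
# >>> conv_output_length(4, 2, 'valid', 1)
# 3
# >>> conv_output_length(4, 2, 'full', 1)
# 5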

def conv_input_length(output_length, filter_size, padding, stride):
  """Determines input length of a convolution given output length.

  Arguments:
    output_length: integer.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The input length (integer).
  """
  if output_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}
  if padding == 'same':
    pad = filter_size // 2
  elif padding == 'valid':
    pad = 0
  elif padding == 'full':
    pad = filter_size - 1
  return (output_length - 1) * stride - 2 * pad + filter_size


def deconv_output_length(input_length,
                         filter_size,
                         padding,
                         output_padding=None,
                         stride=0,
                         dilation=1):
  """Determines output length of a transposed convolution given input length.

  Arguments:
    input_length: Integer.
    filter_size: Integer.
    padding: one of `"same"`, `"valid"`, `"full"`.
    output_padding: Integer, amount of padding along the output dimension. Can
      be set to `None` in which case the output length is inferred.
    stride: Integer.
    dilation: Integer.

  Returns:
    The output length (integer).
  """
  assert padding in {'same', 'valid', 'full'}
  if input_length is None:
    return None

  # Get the dilated kernel size
  filter_size = filter_size + (filter_size - 1) * (dilation - 1)

  # Infer length if output padding is None, else compute the exact length
  if output_padding is None:
    if padding == 'valid':
      length = input_length * stride + max(filter_size - stride, 0)
    elif padding == 'full':
      length = input_length * stride - (stride + filter_size - 2)
    elif padding == 'same':
      length = input_length * stride

  else:
    if padding == 'same':
      pad = filter_size // 2
    elif padding == 'valid':
      pad = 0
    elif padding == 'full':
      pad = filter_size - 1

    length = ((input_length - 1) * stride + filter_size - 2 * pad +
              output_padding)
  return length
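
# Editor's note: illustrative values, not part of the original file; they
# match the expectations in `conv_utils_test.py` later in this diff. With
# `output_padding=None` a stride-2 transposed convolution roughly doubles
# the length:
#
# >>> deconv_output_length(4, 2, 'same', stride=2)
# 8
# >>> deconv_output_length(4, 2, 'valid', stride=2)
# 8
# >>> deconv_output_length(4, 2, 'same', output_padding=1, stride=2)
# 7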

def normalize_data_format(value):
  if value is None:
    value = backend.image_data_format()
  data_format = value.lower()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('The `data_format` argument must be one of '
                     '"channels_first", "channels_last". Received: ' +
                     str(value))
  return data_format


def normalize_padding(value):
  if isinstance(value, (list, tuple)):
    return value
  padding = value.lower()
  if padding not in {'valid', 'same', 'causal'}:
    raise ValueError('The `padding` argument must be a list/tuple or one of '
                     '"valid", "same" (or "causal", only for `Conv1D`). '
                     'Received: ' + str(padding))
  return padding


def convert_kernel(kernel):
  """Converts a Numpy kernel matrix from Theano format to TensorFlow format.

  Also works reciprocally, since the transformation is its own inverse.

  This is used for converting legacy Theano-saved model files.

  Arguments:
    kernel: Numpy array (3D, 4D or 5D).

  Returns:
    The converted kernel.

  Raises:
    ValueError: in case of invalid kernel shape or invalid data_format.
  """
  kernel = np.asarray(kernel)
  if not 3 <= kernel.ndim <= 5:
    raise ValueError('Invalid kernel shape:', kernel.shape)
  # Flip all spatial axes; keep the last two (input/output channel) axes as-is.
  slices = [slice(None, None, -1) for _ in range(kernel.ndim)]
  no_flip = (slice(None, None), slice(None, None))
  slices[-2:] = no_flip
  return np.copy(kernel[tuple(slices)])


def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
  """Compute a mask representing the connectivity of a convolution operation.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
  output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
  of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
  indicating pairs of input and output locations that are connected by a weight.

  Example:

    >>> input_shape = (4,)
    >>> kernel_shape = (2,)
    >>> strides = (1,)
    >>> padding = "valid"
    >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
    array([[ True, False, False],
           [ True,  True, False],
           [False,  True,  True],
           [False, False,  True]])

    where rows and columns correspond to inputs and outputs respectively.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    A boolean 2N-D `np.ndarray` of shape
    `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
    is the spatial shape of the output. `True` entries in the mask represent
    pairs of input-output locations that are connected by a weight.

  Raises:
    ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
      same number of dimensions.
    NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
  """
  if padding not in {'same', 'valid'}:
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)

  mask_shape = input_shape + output_shape
  mask = np.zeros(mask_shape, np.bool)

  output_axes_ticks = [range(dim) for dim in output_shape]
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    for input_position in itertools.product(*input_axes_ticks):
      mask[input_position + output_position] = True

  return mask


def conv_kernel_idxs(input_shape, kernel_shape, strides, padding, filters_in,
                     filters_out, data_format):
  """Yields output-input tuples of indices in a CNN layer.

  The generator iterates over all `(output_idx, input_idx)` tuples, where
  `output_idx` is an integer index in a flattened tensor representing a single
  output image of a convolutional layer that is connected (via the layer
  weights) to the respective single input image at `input_idx`.

  Example:

    >>> input_shape = (2, 2)
    >>> kernel_shape = (2, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> filters_in = 1
    >>> filters_out = 1
    >>> data_format = "channels_last"
    >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
    ...                       filters_in, filters_out, data_format))
    [(0, 0), (0, 2), (1, 1), (1, 3)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
    filters_in: `int`, number of filters in the input to the layer.
    filters_out: `int`, number of filters in the output of the layer.
    data_format: string, "channels_first" or "channels_last".

  Yields:
    The next tuple `(output_idx, input_idx)`, where `output_idx` is an integer
    index in a flattened tensor representing a single output image of a
    convolutional layer that is connected (via the layer weights) to the
    respective single input image at `input_idx`.

  Raises:
    ValueError: if `data_format` is neither `"channels_last"` nor
      `"channels_first"`, or if the number of strides, input, and kernel
      dimensions do not match.
    NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
  """
  if padding not in ('same', 'valid'):
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
  output_axes_ticks = [range(dim) for dim in output_shape]

  if data_format == 'channels_first':
    concat_idxs = lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
  elif data_format == 'channels_last':
    concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (filter_idx,)
  else:
    raise ValueError('Data format %s not recognized.'
                     '`data_format` must be "channels_first" or '
                     '"channels_last".' % data_format)

  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    for input_position in itertools.product(*input_axes_ticks):
      for f_in in range(filters_in):
        for f_out in range(filters_out):
          out_idx = np.ravel_multi_index(
              multi_index=concat_idxs(output_position, f_out),
              dims=concat_idxs(output_shape, filters_out))
          in_idx = np.ravel_multi_index(
              multi_index=concat_idxs(input_position, f_in),
              dims=concat_idxs(input_shape, filters_in))
          yield (out_idx, in_idx)


def conv_connected_inputs(input_shape, kernel_shape, output_position, strides,
                          padding):
  """Return locations of the input connected to an output position.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
  returns N ranges specifying the input region that was convolved with the
  kernel to produce the output at position
  `output_position = (p_out1, ..., p_outN)`.

  Example:

    >>> input_shape = (4, 4)
    >>> kernel_shape = (2, 1)
    >>> output_position = (1, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
    ...                       strides, padding)
    [range(1, 3), range(1, 2)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single
      position in the output of the convolution.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    N ranges `[[p_in_left1, ..., p_in_right1], ...,
    [p_in_leftN, ..., p_in_rightN]]` specifying the region in the
    input connected to output_position.
  """
  ranges = []

  ndims = len(input_shape)
  for d in range(ndims):
    left_shift = int(kernel_shape[d] / 2)
    right_shift = kernel_shape[d] - left_shift

    center = output_position[d] * strides[d]

    if padding == 'valid':
      center += left_shift

    start = max(0, center - left_shift)
    end = min(input_shape[d], center + right_shift)

    ranges.append(range(start, end))

  return ranges


def conv_output_shape(input_shape, kernel_shape, strides, padding):
  """Return the output shape of an N-D convolution.

  Forces dimensions where input is empty (size 0) to remain empty.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
  """
  dims = range(len(kernel_shape))
  output_shape = [
      conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d])
      for d in dims
  ]
  output_shape = tuple(
      [0 if input_shape[d] == 0 else output_shape[d] for d in dims])
  return output_shape
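
# Editor's note: illustrative values, not part of the original file: a 2x2
# 'valid' kernel at stride 1 shrinks each spatial dimension by one, and
# empty (size-0) input dimensions stay empty:
#
# >>> conv_output_shape((4, 4), (2, 2), (1, 1), 'valid')
# (3, 3)
# >>> conv_output_shape((0, 4), (2, 2), (1, 1), 'same')
# (0, 4)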
@ -1,340 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for conv_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.python.frozen_keras.utils import conv_utils
from tensorflow.python.platform import test


def _get_const_output_shape(input_shape, dim):
  return tuple([min(d, dim) for d in input_shape])


input_shapes = [
    (0,),
    (0, 0),
    (1,),
    (2,),
    (3,),
    (1, 0),
    (0, 3),
    (1, 1),
    (1, 2),
    (3, 1),
    (2, 2),
    (3, 3),
    (1, 0, 1),
    (5, 2, 3),
    (3, 5, 6, 7, 0),
    (3, 2, 2, 4, 4),
    (1, 2, 3, 4, 7, 2),
]


class TestBasicConvUtilsTest(test.TestCase):

  def test_convert_data_format(self):
    self.assertEqual('NCDHW', conv_utils.convert_data_format(
        'channels_first', 5))
    self.assertEqual('NCHW', conv_utils.convert_data_format(
        'channels_first', 4))
    self.assertEqual('NCW', conv_utils.convert_data_format('channels_first', 3))
    self.assertEqual('NHWC', conv_utils.convert_data_format('channels_last', 4))
    self.assertEqual('NWC', conv_utils.convert_data_format('channels_last', 3))
    self.assertEqual('NDHWC', conv_utils.convert_data_format(
        'channels_last', 5))

    with self.assertRaises(ValueError):
      conv_utils.convert_data_format('invalid', 2)

  def test_normalize_tuple(self):
    self.assertEqual((2, 2, 2),
                     conv_utils.normalize_tuple(2, n=3, name='strides'))
    self.assertEqual((2, 1, 2),
                     conv_utils.normalize_tuple((2, 1, 2), n=3, name='strides'))

    with self.assertRaises(ValueError):
      conv_utils.normalize_tuple((2, 1), n=3, name='strides')

    with self.assertRaises(ValueError):
      conv_utils.normalize_tuple(None, n=3, name='strides')

  def test_normalize_data_format(self):
    self.assertEqual('channels_last',
                     conv_utils.normalize_data_format('Channels_Last'))
    self.assertEqual('channels_first',
                     conv_utils.normalize_data_format('CHANNELS_FIRST'))

    with self.assertRaises(ValueError):
      conv_utils.normalize_data_format('invalid')

  def test_normalize_padding(self):
    self.assertEqual('same', conv_utils.normalize_padding('SAME'))
    self.assertEqual('valid', conv_utils.normalize_padding('VALID'))

    with self.assertRaises(ValueError):
      conv_utils.normalize_padding('invalid')

  def test_conv_output_length(self):
    self.assertEqual(4, conv_utils.conv_output_length(4, 2, 'same', 1, 1))
    self.assertEqual(2, conv_utils.conv_output_length(4, 2, 'same', 2, 1))
    self.assertEqual(3, conv_utils.conv_output_length(4, 2, 'valid', 1, 1))
    self.assertEqual(2, conv_utils.conv_output_length(4, 2, 'valid', 2, 1))
    self.assertEqual(5, conv_utils.conv_output_length(4, 2, 'full', 1, 1))
    self.assertEqual(3, conv_utils.conv_output_length(4, 2, 'full', 2, 1))
    self.assertEqual(2, conv_utils.conv_output_length(5, 2, 'valid', 2, 2))

  def test_conv_input_length(self):
    self.assertEqual(3, conv_utils.conv_input_length(4, 2, 'same', 1))
    self.assertEqual(2, conv_utils.conv_input_length(2, 2, 'same', 2))
    self.assertEqual(4, conv_utils.conv_input_length(3, 2, 'valid', 1))
    self.assertEqual(4, conv_utils.conv_input_length(2, 2, 'valid', 2))
    self.assertEqual(3, conv_utils.conv_input_length(4, 2, 'full', 1))
    self.assertEqual(4, conv_utils.conv_input_length(3, 2, 'full', 2))

  def test_deconv_output_length(self):
    self.assertEqual(4, conv_utils.deconv_output_length(4, 2, 'same', stride=1))
    self.assertEqual(8, conv_utils.deconv_output_length(4, 2, 'same', stride=2))
    self.assertEqual(5, conv_utils.deconv_output_length(
        4, 2, 'valid', stride=1))
    self.assertEqual(8, conv_utils.deconv_output_length(
        4, 2, 'valid', stride=2))
    self.assertEqual(3, conv_utils.deconv_output_length(4, 2, 'full', stride=1))
    self.assertEqual(6, conv_utils.deconv_output_length(4, 2, 'full', stride=2))
    self.assertEqual(
        5,
        conv_utils.deconv_output_length(
            4, 2, 'same', output_padding=2, stride=1))
    self.assertEqual(
        7,
        conv_utils.deconv_output_length(
            4, 2, 'same', output_padding=1, stride=2))
    self.assertEqual(
        7,
        conv_utils.deconv_output_length(
            4, 2, 'valid', output_padding=2, stride=1))
    self.assertEqual(
        9,
        conv_utils.deconv_output_length(
            4, 2, 'valid', output_padding=1, stride=2))
    self.assertEqual(
        5,
        conv_utils.deconv_output_length(
            4, 2, 'full', output_padding=2, stride=1))
    self.assertEqual(
        7,
        conv_utils.deconv_output_length(
            4, 2, 'full', output_padding=1, stride=2))
    self.assertEqual(
        5,
        conv_utils.deconv_output_length(
            4, 2, 'same', output_padding=1, stride=1, dilation=2))
    self.assertEqual(
        12,
        conv_utils.deconv_output_length(
            4, 2, 'valid', output_padding=2, stride=2, dilation=3))
    self.assertEqual(
        6,
        conv_utils.deconv_output_length(
            4, 2, 'full', output_padding=2, stride=2, dilation=3))


@parameterized.parameters(input_shapes)
class TestConvUtils(test.TestCase, parameterized.TestCase):

  def test_conv_kernel_mask_fc(self, *input_shape):
    padding = 'valid'
    kernel_shape = input_shape
    ndims = len(input_shape)
    strides = (1,) * ndims
    output_shape = _get_const_output_shape(input_shape, dim=1)
    mask = np.ones(input_shape + output_shape, np.bool)
    self.assertAllEqual(
        mask,
        conv_utils.conv_kernel_mask(
            input_shape,
            kernel_shape,
            strides,
            padding
        )
    )

  def test_conv_kernel_mask_diag(self, *input_shape):
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = (1,) * ndims

    for padding in ['valid', 'same']:
      mask = np.identity(int(np.prod(input_shape)), np.bool)
      mask = np.reshape(mask, input_shape * 2)
      self.assertAllEqual(
          mask,
          conv_utils.conv_kernel_mask(
              input_shape,
              kernel_shape,
              strides,
              padding
          )
      )

  def test_conv_kernel_mask_full_stride(self, *input_shape):
    padding = 'valid'
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = tuple([max(d, 1) for d in input_shape])
    output_shape = _get_const_output_shape(input_shape, dim=1)

    mask = np.zeros(input_shape + output_shape, np.bool)
    if all(d > 0 for d in mask.shape):
      mask[(0,) * len(output_shape)] = True

    self.assertAllEqual(
        mask,
        conv_utils.conv_kernel_mask(
            input_shape,
            kernel_shape,
            strides,
            padding
        )
    )

  def test_conv_kernel_mask_almost_full_stride(self, *input_shape):
    padding = 'valid'
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = tuple([max(d - 1, 1) for d in input_shape])
    output_shape = _get_const_output_shape(input_shape, dim=2)

    mask = np.zeros(input_shape + output_shape, np.bool)
    if all(d > 0 for d in mask.shape):
      for in_position in itertools.product(*[[0, d - 1] for d in input_shape]):
        out_position = tuple([min(p, 1) for p in in_position])
        mask[in_position + out_position] = True

    self.assertAllEqual(
        mask,
        conv_utils.conv_kernel_mask(
            input_shape,
            kernel_shape,
            strides,
            padding
        )
    )

  def test_conv_kernel_mask_rect_kernel(self, *input_shape):
    padding = 'valid'
    ndims = len(input_shape)
    strides = (1,) * ndims

    for d in range(ndims):
      kernel_shape = [1] * ndims
      kernel_shape[d] = input_shape[d]

      output_shape = list(input_shape)
      output_shape[d] = min(1, input_shape[d])

      mask = np.identity(int(np.prod(input_shape)), np.bool)
      mask = np.reshape(mask, input_shape * 2)

      for p in itertools.product(*[range(input_shape[dim])
                                   for dim in range(ndims)]):
        p = list(p)
        p[d] = slice(None)
        mask[tuple(p * 2)] = True

      mask = np.take(mask, range(0, min(1, input_shape[d])), ndims + d)

      self.assertAllEqual(
          mask,
          conv_utils.conv_kernel_mask(
              input_shape,
              kernel_shape,
              strides,
              padding
          )
      )

  def test_conv_kernel_mask_wrong_padding(self, *input_shape):
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = (1,) * ndims

    conv_utils.conv_kernel_mask(
        input_shape,
        kernel_shape,
        strides,
        'valid'
    )

    conv_utils.conv_kernel_mask(
        input_shape,
        kernel_shape,
        strides,
        'same'
    )

    self.assertRaises(NotImplementedError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'full')

  def test_conv_kernel_mask_wrong_dims(self, *input_shape):
    kernel_shape = 1
    strides = 1

    conv_utils.conv_kernel_mask(
        input_shape,
        kernel_shape,
        strides,
        'valid'
    )

    ndims = len(input_shape)

    kernel_shape = (2,) * (ndims + 1)
    self.assertRaises(ValueError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'same')

    strides = (1,) * ndims
    self.assertRaises(ValueError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'valid')

    kernel_shape = (1,) * ndims
    strides = (2,) * (ndims - 1)
    self.assertRaises(ValueError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'valid')

    strides = (2,) * ndims
    conv_utils.conv_kernel_mask(
        input_shape,
        kernel_shape,
        strides,
        'valid'
    )


if __name__ == '__main__':
  test.main()
@ -1,612 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python utilities required by Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import binascii
import codecs
import marshal
import os
import re
import types as python_types

import numpy as np
import six

from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

_GLOBAL_CUSTOM_OBJECTS = {}
_GLOBAL_CUSTOM_NAMES = {}

# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required.
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'


class CustomObjectScope(object):
  """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.

  Code within a `with` statement will be able to access custom objects
  by name. Changes to global custom objects persist
  within the enclosing `with` statement. At the end of the `with` statement,
  global custom objects are reverted to the state
  at the beginning of the `with` statement.

  Example:

  Consider a custom object `MyObject` (e.g. a class):

  ```python
  with CustomObjectScope({'MyObject': MyObject}):
    layer = Dense(..., kernel_regularizer='MyObject')
    # save, load, etc. will recognize custom object by name
  ```
  """

  def __init__(self, *args):
    self.custom_objects = args
    self.backup = None

  def __enter__(self):
    self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
    for objects in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(objects)
    return self

  def __exit__(self, *args, **kwargs):
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)


def custom_object_scope(*args):
  """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.

  Convenience wrapper for `CustomObjectScope`.
  Code within a `with` statement will be able to access custom objects
  by name. Changes to global custom objects persist
  within the enclosing `with` statement. At the end of the `with` statement,
  global custom objects are reverted to the state
  at the beginning of the `with` statement.

  Example:

  Consider a custom object `MyObject`:

  ```python
  with custom_object_scope({'MyObject': MyObject}):
    layer = Dense(..., kernel_regularizer='MyObject')
    # save, load, etc. will recognize custom object by name
  ```

  Arguments:
    *args: Variable length list of dictionaries of name, class pairs to add to
      custom objects.

  Returns:
    Object of type `CustomObjectScope`.
  """
  return CustomObjectScope(*args)


def get_custom_objects():
  """Retrieves a live reference to the global dictionary of custom objects.

  Updating and clearing custom objects using `custom_object_scope`
  is preferred, but `get_custom_objects` can
  be used to directly access `_GLOBAL_CUSTOM_OBJECTS`.

  Example:

  ```python
  get_custom_objects().clear()
  get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
    Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  return _GLOBAL_CUSTOM_OBJECTS


def serialize_keras_class_and_config(cls_name, cls_config):
  """Returns the serialization of the class with the given config."""
  return {'class_name': cls_name, 'config': cls_config}


def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework.

  This decorator injects the decorated class or function into the Keras custom
  object dictionary, so that it can be serialized and deserialized without
  needing an entry in the user-provided custom object dict. It also injects a
  function that Keras will call to get the object's serializable string key.

  Note that to be serialized and deserialized, classes must implement the
  `get_config()` method. Functions do not have this requirement.

  The object will be registered under the key 'package>name', where `name`
  defaults to the object name if not passed.

  Arguments:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class's name will be used.

  Returns:
    A decorator that registers the decorated class with the passed names.
  """

  def decorator(arg):
    """Registers a class with the Keras serialization framework."""
    class_name = name if name is not None else arg.__name__
    registered_name = package + '>' + class_name

    if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      raise ValueError(
          '%s has already been registered to %s' %
          (registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name]))

    if arg in _GLOBAL_CUSTOM_NAMES:
      raise ValueError('%s has already been registered to %s' %
                       (arg, _GLOBAL_CUSTOM_NAMES[arg]))
    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name

    return arg

  return decorator
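
# Editor's note: an illustrative sketch, not part of the original file; the
# class name `MyDense` is hypothetical. The class is stored under the key
# 'MyPackage>MyDense' and must define `get_config()`:
#
# @register_keras_serializable(package='MyPackage')
# class MyDense(object):
#
#   def get_config(self):
#     return {}
#
# assert get_registered_name(MyDense) == 'MyPackage>MyDense'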

def get_registered_name(obj):
  """Returns the name registered to an object within the Keras framework.

  This function is part of the Keras serialization and deserialization
  framework. It maps objects to the string names associated with those objects
  for serialization/deserialization.

  Args:
    obj: The object to look up.

  Returns:
    The name associated with the object, or the default Python name if the
    object is not registered.
  """
  if obj in _GLOBAL_CUSTOM_NAMES:
    return _GLOBAL_CUSTOM_NAMES[obj]
  else:
    return obj.__name__


@tf_contextlib.contextmanager
def skip_failed_serialization():
  global _SKIP_FAILED_SERIALIZATION
  prev = _SKIP_FAILED_SERIALIZATION
  try:
    _SKIP_FAILED_SERIALIZATION = True
    yield
  finally:
    _SKIP_FAILED_SERIALIZATION = prev


def get_registered_object(name, custom_objects=None, module_objects=None):
  """Returns the class associated with `name` if it is registered with Keras.

  This function is part of the Keras serialization and deserialization
  framework. It maps strings to the objects associated with them for
  serialization/deserialization.

  Example:

  ```python
  def from_config(cls, config, custom_objects=None):
    if 'my_custom_object_name' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['my_custom_object_name'], custom_objects=custom_objects)
  ```

  Args:
    name: The name to look up.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    module_objects: A dictionary of custom objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.

  Returns:
    An instantiable class associated with `name`, or None if no such class
    exists.
  """
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]
  elif custom_objects and name in custom_objects:
    return custom_objects[name]
  elif module_objects and name in module_objects:
    return module_objects[name]
  return None


def serialize_keras_object(instance):
  """Serializes a Keras object into a JSON-compatible representation."""
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if hasattr(instance, 'get_config'):
    name = get_registered_name(instance.__class__)
    try:
      config = instance.get_config()
    except NotImplementedError as e:
      if _SKIP_FAILED_SERIALIZATION:
        return serialize_keras_class_and_config(
            name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
      raise e
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, six.string_types):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = serialize_keras_object(item)
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        serialization_config[key] = item

    name = get_registered_name(instance.__class__)
    return serialize_keras_class_and_config(name, serialization_config)
  if hasattr(instance, '__name__'):
    return get_registered_name(instance)
  raise ValueError('Cannot serialize', instance)


def get_custom_objects_by_name(item, custom_objects=None):
  """Returns the item if it is in either local or global custom objects."""
  if item in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[item]
  elif custom_objects and item in custom_objects:
    return custom_objects[item]
  return None


def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class name and config for a serialized keras object."""
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = get_registered_object(class_name, custom_objects, module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  cls_config = config['config']
  deserialized_objects = {}
  for key, item in cls_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    # TODO(momernick): Should this also have 'module_objects'?
    elif (isinstance(item, six.string_types) and
          tf_inspect.isfunction(get_registered_object(item, custom_objects))):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      deserialized_objects[key] = get_registered_object(item, custom_objects)
  for key, item in deserialized_objects.items():
    cls_config[key] = deserialized_objects[key]

  return (cls, cls_config)


def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      custom_objects = custom_objects or {}

      if 'custom_objects' in arg_spec.args:
        return cls.from_config(
            cls_config,
            custom_objects=dict(
                list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                list(custom_objects.items())))
      with CustomObjectScope(custom_objects):
        return cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # In this case by convention `config` holds
      # the kwargs of the function.
      custom_objects = custom_objects or {}
      with CustomObjectScope(custom_objects):
        return cls(**cls_config)
  elif isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      obj = module_objects.get(object_name)
      if obj is None:
        raise ValueError('Unknown ' + printable_module_name + ': ' +
                         object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return it as-is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
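
# Editor's note: an illustrative sketch, not part of the original file. A
# config dict round-trips back into an object; `L1L2` here refers to the
# regularizer class from the frozen_keras regularizers file earlier in this
# diff.
#
# config = {'class_name': 'L1L2', 'config': {'l1': 0.01, 'l2': 0.0}}
# regularizer = deserialize_keras_object(
#     config,
#     module_objects={'L1L2': L1L2},
#     printable_module_name='regularizer')
# # `regularizer` is now an L1L2 instance built via `L1L2.from_config`.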

def func_dump(func):
  """Serializes a user-defined function.

  Arguments:
    func: the function to serialize.

  Returns:
    A tuple `(code, defaults, closure)`.
  """
  if os.name == 'nt':
    raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/')
    code = codecs.encode(raw_code, 'base64').decode('ascii')
  else:
    raw_code = marshal.dumps(func.__code__)
    code = codecs.encode(raw_code, 'base64').decode('ascii')
  defaults = func.__defaults__
  if func.__closure__:
    closure = tuple(c.cell_contents for c in func.__closure__)
  else:
    closure = None
  return code, defaults, closure


def func_load(code, defaults=None, closure=None, globs=None):
  """Deserializes a user-defined function.

  Arguments:
    code: bytecode of the function.
    defaults: defaults of the function.
    closure: closure of the function.
    globs: dictionary of global objects.

  Returns:
    A function object.
  """
  if isinstance(code, (tuple, list)):  # unpack previous dump
    code, defaults, closure = code
    if isinstance(defaults, list):
      defaults = tuple(defaults)

  def ensure_value_to_cell(value):
    """Ensures that a value is converted to a python cell object.

    Arguments:
      value: Any value that needs to be cast to the cell type.

    Returns:
      A value wrapped as a cell object (see function "func_load").
    """

    def dummy_fn():
      # pylint: disable=pointless-statement
      value  # just access it so it gets captured in .__closure__

    cell_value = dummy_fn.__closure__[0]
    if not isinstance(value, type(cell_value)):
      return cell_value
    return value

  if closure is not None:
    closure = tuple(ensure_value_to_cell(_) for _ in closure)
  try:
    raw_code = codecs.decode(code.encode('ascii'), 'base64')
  except (UnicodeEncodeError, binascii.Error):
    raw_code = code.encode('raw_unicode_escape')
  code = marshal.loads(raw_code)
  if globs is None:
    globs = globals()
  return python_types.FunctionType(
      code, globs, name=code.co_name, argdefs=defaults, closure=closure)
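
# Editor's note: an illustrative round trip, not part of the original file.
# `func_dump` marshals the bytecode to a base64 string and `func_load`
# rebuilds an equivalent function from it:
#
# def add(a, b=1):
#   return a + b
#
# code, defaults, closure = func_dump(add)
# restored = func_load(code, defaults=defaults, closure=closure)
# assert restored(2) == 3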

def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Arguments:
    fn: Callable to inspect.
    name: Check if `fn` can be called with `name` as a keyword argument.
    accept_all: What to return if there is no parameter called `name` but the
      function accepts a `**kwargs` argument.

  Returns:
    bool, whether `fn` accepts a `name` keyword argument.
  """
  arg_spec = tf_inspect.getfullargspec(fn)
  if accept_all and arg_spec.varkw is not None:
    return True
  return name in arg_spec.args


def make_batches(size, batch_size):
  """Returns a list of batch indices (tuples of indices).

  Arguments:
    size: Integer, total size of the data to slice into batches.
    batch_size: Integer, batch size.

  Returns:
    A list of tuples of array indices.
  """
  num_batches = int(np.ceil(size / float(batch_size)))
  return [(i * batch_size, min(size, (i + 1) * batch_size))
          for i in range(0, num_batches)]
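
# Editor's note: an illustrative value, not part of the original file; the
# last batch is truncated to the remaining elements:
#
# >>> make_batches(10, 3)
# [(0, 3), (3, 6), (6, 9), (9, 10)]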
|
||||
|
||||
def slice_arrays(arrays, start=None, stop=None):
|
||||
"""Slice an array or list of arrays.
|
||||
|
||||
This takes an array-like, or a list of
|
||||
array-likes, and outputs:
|
||||
- arrays[start:stop] if `arrays` is an array-like
|
||||
- [x[start:stop] for x in arrays] if `arrays` is a list
|
||||
|
||||
Can also work on list/array of indices: `slice_arrays(x, indices)`
|
||||
|
||||
Arguments:
|
||||
arrays: Single array or list of arrays.
|
||||
start: can be an integer index (start index) or a list/array of indices
|
||||
stop: integer (stop index); should be None if `start` was a list.
|
||||
|
||||
Returns:
|
||||
A slice of the array(s).
|
||||
|
||||
Raises:
|
||||
ValueError: If the value of start is a list and stop is not None.
|
||||
"""
|
||||
if arrays is None:
|
||||
return [None]
|
||||
if isinstance(start, list) and stop is not None:
|
||||
raise ValueError('The stop argument has to be None if the value of start '
|
||||
'is a list.')
|
||||
elif isinstance(arrays, list):
|
||||
if hasattr(start, '__len__'):
|
||||
# hdf5 datasets only support list objects as indices
|
||||
if hasattr(start, 'shape'):
|
||||
start = start.tolist()
|
||||
return [None if x is None else x[start] for x in arrays]
|
||||
return [
|
||||
None if x is None else
|
||||
None if not hasattr(x, '__getitem__') else x[start:stop] for x in arrays
|
||||
]
|
||||
else:
|
||||
if hasattr(start, '__len__'):
|
||||
if hasattr(start, 'shape'):
|
||||
start = start.tolist()
|
||||
return arrays[start]
|
||||
if hasattr(start, '__getitem__'):
|
||||
return arrays[start:stop]
|
||||
return [None]
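
# Illustrative behavior (not part of the original source):
#
#   slice_arrays([np.arange(5), None], 1, 3)  # -> [array([1, 2]), None]
#   slice_arrays(np.arange(5), [0, 2])        # -> array([0, 2]), index list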


def to_list(x):
  """Normalizes a list/tensor into a list.

  If a tensor is passed, we return
  a list of size 1 containing the tensor.

  Arguments:
      x: target object to be normalized.

  Returns:
      A list.
  """
  if isinstance(x, list):
    return x
  return [x]


def to_snake_case(name):
  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
  # If the class is private the name starts with "_" which is not secure
  # for creating scopes. We prefix the name with "private" in this case.
  if insecure[0] != '_':
    return insecure
  return 'private' + insecure
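
# E.g. (illustrative): to_snake_case('DenseLayer') -> 'dense_layer', while
# to_snake_case('_Foo') -> 'private__foo', since a name with a leading
# underscore would be unsafe as a scope name.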


def is_all_none(structure):
  iterable = nest.flatten(structure)
  # We cannot use Python's `any` because the iterable may return Tensors.
  for element in iterable:
    if element is not None:
      return False
  return True


def check_for_unexpected_keys(name, input_dict, expected_values):
  unknown = set(input_dict.keys()).difference(expected_values)
  if unknown:
    raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
                     'the following keys: {}'.format(name, list(unknown),
                                                     expected_values))


def validate_kwargs(kwargs,
                    allowed_kwargs,
                    error_message='Keyword argument not understood:'):
  """Checks that all keyword arguments are in the set of allowed keys."""
  for kwarg in kwargs:
    if kwarg not in allowed_kwargs:
      raise TypeError(error_message, kwarg)
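
# Illustrative use (not in the original source): a layer __init__ might call
# validate_kwargs(kwargs, {'name', 'dtype', 'trainable'}) so that a typo such
# as `trainble=False` raises a TypeError instead of being silently ignored.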


def validate_config(config):
  """Determines whether config appears to be a valid layer config."""
  return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config


def default(method):
  """Decorates a method to detect overrides in subclasses."""
  method._is_default = True  # pylint: disable=protected-access
  return method


def is_default(method):
  """Check if a method is decorated with the `default` wrapper."""
  return getattr(method, '_is_default', False)
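
# Sketch of the intended pattern (illustrative): a base class marks a hook
# with @default; a subclass that overrides the hook loses the marker, so the
# framework can detect whether custom behavior was supplied.
#
#   class Base(object):
#
#     @default
#     def build(self):
#       pass
#
#   class Child(Base):
#
#     def build(self):
#       pass
#
#   assert is_default(Base.build) and not is_default(Child.build)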

@ -1,321 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras generic Python utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import keras
from tensorflow.python.frozen_keras import regularizers
from tensorflow.python.frozen_keras.utils import generic_utils
from tensorflow.python.platform import test


class HasArgTest(test.TestCase):

  def test_has_arg(self):

    def f_x(x):
      return x

    def f_x_args(x, *args):
      _ = args
      return x

    def f_x_kwargs(x, **kwargs):
      _ = kwargs
      return x

    self.assertTrue(generic_utils.has_arg(
        f_x, 'x', accept_all=False))
    self.assertFalse(generic_utils.has_arg(
        f_x, 'y', accept_all=False))
    self.assertTrue(generic_utils.has_arg(
        f_x_args, 'x', accept_all=False))
    self.assertFalse(generic_utils.has_arg(
        f_x_args, 'y', accept_all=False))
    self.assertTrue(generic_utils.has_arg(
        f_x_kwargs, 'x', accept_all=False))
    self.assertFalse(generic_utils.has_arg(
        f_x_kwargs, 'y', accept_all=False))
    self.assertTrue(generic_utils.has_arg(
        f_x_kwargs, 'y', accept_all=True))


class TestCustomObjectScope(test.TestCase):

  def test_custom_object_scope(self):

    def custom_fn():
      pass

    class CustomClass(object):
      pass

    with generic_utils.custom_object_scope(
        {'CustomClass': CustomClass, 'custom_fn': custom_fn}):
      # Disable the activation test since it's not under the frozen_keras
      # package.
      # act = keras.activations.get('custom_fn')
      # self.assertEqual(act, custom_fn)
      cl = regularizers.get('CustomClass')
      self.assertEqual(cl.__class__, CustomClass)


class SerializeKerasObjectTest(test.TestCase):

  def test_serialize_none(self):
    serialized = generic_utils.serialize_keras_object(None)
    self.assertEqual(serialized, None)
    deserialized = generic_utils.deserialize_keras_object(
        serialized)
    self.assertEqual(deserialized, None)

  def test_serialize_custom_class_with_default_name(self):

    @generic_utils.register_keras_serializable()
    class TestClass(object):

      def __init__(self, value):
        self._value = value

      def get_config(self):
        return {'value': self._value}

    serialized_name = 'Custom>TestClass'
    inst = TestClass(value=10)
    class_name = generic_utils._GLOBAL_CUSTOM_NAMES[TestClass]
    self.assertEqual(serialized_name, class_name)
    config = generic_utils.serialize_keras_object(inst)
    self.assertEqual(class_name, config['class_name'])
    new_inst = generic_utils.deserialize_keras_object(config)
    self.assertIsNot(inst, new_inst)
    self.assertIsInstance(new_inst, TestClass)
    self.assertEqual(10, new_inst._value)

    # Make sure registering a new class with the same name will fail.
    with self.assertRaisesRegex(ValueError, '.*has already been registered.*'):

      @generic_utils.register_keras_serializable()  # pylint: disable=function-redefined
      class TestClass(object):

        def __init__(self, value):
          self._value = value

        def get_config(self):
          return {'value': self._value}

  def test_serialize_custom_class_with_custom_name(self):

    @generic_utils.register_keras_serializable(
        'TestPackage', 'CustomName')
    class OtherTestClass(object):

      def __init__(self, val):
        self._val = val

      def get_config(self):
        return {'val': self._val}

    serialized_name = 'TestPackage>CustomName'
    inst = OtherTestClass(val=5)
    class_name = generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
    self.assertEqual(serialized_name, class_name)
    fn_class_name = generic_utils.get_registered_name(
        OtherTestClass)
    self.assertEqual(fn_class_name, class_name)

    cls = generic_utils.get_registered_object(fn_class_name)
    self.assertEqual(OtherTestClass, cls)

    config = generic_utils.serialize_keras_object(inst)
    self.assertEqual(class_name, config['class_name'])
    new_inst = generic_utils.deserialize_keras_object(config)
    self.assertIsNot(inst, new_inst)
    self.assertIsInstance(new_inst, OtherTestClass)
    self.assertEqual(5, new_inst._val)

  def test_serialize_custom_function(self):

    @generic_utils.register_keras_serializable()
    def my_fn():
      return 42

    serialized_name = 'Custom>my_fn'
    class_name = generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
    self.assertEqual(serialized_name, class_name)
    fn_class_name = generic_utils.get_registered_name(my_fn)
    self.assertEqual(fn_class_name, class_name)

    config = generic_utils.serialize_keras_object(my_fn)
    self.assertEqual(class_name, config)
    fn = generic_utils.deserialize_keras_object(config)
    self.assertEqual(42, fn())

    fn_2 = generic_utils.get_registered_object(fn_class_name)
    self.assertEqual(42, fn_2())

  def test_serialize_custom_class_without_get_config_fails(self):

    with self.assertRaisesRegex(
        ValueError, 'Cannot register a class that does '
        'not have a get_config.*'):

      @generic_utils.register_keras_serializable(  # pylint: disable=unused-variable
          'TestPackage', 'TestClass')
      class TestClass(object):

        def __init__(self, value):
          self._value = value

  def test_serializable_object(self):

    class SerializableInt(int):
      """A serializable object to pass out of a test layer's config."""

      def __new__(cls, value):
        return int.__new__(cls, value)

      def get_config(self):
        return {'value': int(self)}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    layer = keras.layers.Dense(
        SerializableInt(3),
        activation='relu',
        kernel_initializer='ones',
        bias_regularizer='l2')
    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(
        config, custom_objects={'SerializableInt': SerializableInt})
    self.assertEqual(new_layer.activation, keras.activations.relu)
    self.assertEqual(new_layer.bias_regularizer.__class__,
                     keras.regularizers.L1L2)
    self.assertEqual(new_layer.units.__class__, SerializableInt)
    self.assertEqual(new_layer.units, 3)

  def test_nested_serializable_object(self):
    class SerializableInt(int):
      """A serializable object to pass out of a test layer's config."""

      def __new__(cls, value):
        return int.__new__(cls, value)

      def get_config(self):
        return {'value': int(self)}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    class SerializableNestedInt(int):
      """A serializable object containing another serializable object."""

      def __new__(cls, value, int_obj):
        obj = int.__new__(cls, value)
        obj.int_obj = int_obj
        return obj

      def get_config(self):
        return {'value': int(self), 'int_obj': self.int_obj}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    nested_int = SerializableInt(4)
    layer = keras.layers.Dense(
        SerializableNestedInt(3, nested_int),
        name='SerializableNestedInt',
        activation='relu',
        kernel_initializer='ones',
        bias_regularizer='l2')
    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(
        config,
        custom_objects={
            'SerializableInt': SerializableInt,
            'SerializableNestedInt': SerializableNestedInt
        })
    # Make sure the string field doesn't get converted to a custom object,
    # even when it has the same value.
    self.assertEqual(new_layer.name, 'SerializableNestedInt')
    self.assertEqual(new_layer.activation, keras.activations.relu)
    self.assertEqual(new_layer.bias_regularizer.__class__,
                     keras.regularizers.L1L2)
    self.assertEqual(new_layer.units.__class__, SerializableNestedInt)
    self.assertEqual(new_layer.units, 3)
    self.assertEqual(new_layer.units.int_obj.__class__, SerializableInt)
    self.assertEqual(new_layer.units.int_obj, 4)

  def test_nested_serializable_fn(self):

    def serializable_fn(x):
      """A serializable function to pass out of a test layer's config."""
      return x

    class SerializableNestedInt(int):
      """A serializable object containing a serializable function."""

      def __new__(cls, value, fn):
        obj = int.__new__(cls, value)
        obj.fn = fn
        return obj

      def get_config(self):
        return {'value': int(self), 'fn': self.fn}

      @classmethod
      def from_config(cls, config):
        return cls(**config)

    layer = keras.layers.Dense(
        SerializableNestedInt(3, serializable_fn),
        activation='relu',
        kernel_initializer='ones',
        bias_regularizer='l2')
    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(
        config,
        custom_objects={
            'serializable_fn': serializable_fn,
            'SerializableNestedInt': SerializableNestedInt
        })
    self.assertEqual(new_layer.activation, keras.activations.relu)
    self.assertIsInstance(new_layer.bias_regularizer, keras.regularizers.L1L2)
    self.assertIsInstance(new_layer.units, SerializableNestedInt)
    self.assertEqual(new_layer.units, 3)
    self.assertIs(new_layer.units.fn, serializable_fn)


class SliceArraysTest(test.TestCase):

  def test_slice_arrays(self):
    input_a = list([1, 2, 3])
    self.assertEqual(
        generic_utils.slice_arrays(input_a, start=0),
        [None, None, None])
    self.assertEqual(
        generic_utils.slice_arrays(input_a, stop=3),
        [None, None, None])
    self.assertEqual(
        generic_utils.slice_arrays(input_a, start=0, stop=1),
        [None, None, None])


if __name__ == '__main__':
  test.main()
@ -1,403 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.conv_utils import convert_kernel
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element).

  Arguments:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor.

  Returns:
      List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if not node.inbound_layers:
      # Reached an Input layer, stop recursion.
      return nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg."""
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data,
                  six.string_types) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
    raise ValueError(("%s's %s arg received an invalid value %s. " +
                      'Allowed values are %s.') %
                     (layer_name, arg_name, input_data, allowed_args))


def count_params(weights):
  """Counts the total number of scalars composing the weights.

  Arguments:
      weights: An iterable containing the weights on which to compute params.

  Returns:
      The total number of scalars composing the weights.
  """
  unique_weights = object_identity.ObjectIdentitySet(weights)
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  return int(sum(np.prod(p) for p in standardized_weight_shapes))
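
# E.g. (illustrative): a Dense layer mapping 784 features to 10 units holds a
# (784, 10) kernel and a (10,) bias, so count_params(layer.weights) == 7850.
# Unknown (None) dimensions are counted as 0 by the standardization above.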


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Arguments:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes).
      positions: Relative or absolute positions of log elements in each line.
          If not provided, defaults to `[.33, .55, .67, 1.]`.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(nest.flatten(v[0].inbound_layers)) > 1):
        # if the model has multiple nodes
        # or if the nodes have multiple inbound_layers
        # the model is no longer sequential
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # search for shared layers
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)

  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)
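
# Illustrative usage (not in the original source): capturing the summary as a
# string instead of printing it to stdout, via the `print_fn` hook.
#
#   lines = []
#   print_summary(model, print_fn=lines.append)  # `model` is hypothetical
#   summary_text = '\n'.join(lines)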


def gather_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected trainable weights/variables.
  """
  if not trainable:
    return []
  weights = []
  for layer in sub_layers:
    weights += layer.trainable_weights
  trainable_extra_variables = [
      v for v in extra_variables if v.trainable]
  return weights + trainable_extra_variables


def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the non-trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected non-trainable weights/variables.
  """
  trainable_extra_variables = []
  non_trainable_extra_variables = []
  for v in extra_variables:
    if v.trainable:
      trainable_extra_variables.append(v)
    else:
      non_trainable_extra_variables.append(v)
  weights = []
  for layer in sub_layers:
    weights += layer.non_trainable_weights
  if not trainable:
    trainable_weights = []
    for layer in sub_layers:
      trainable_weights += layer.trainable_weights
    return (trainable_weights + trainable_extra_variables
            + weights + non_trainable_extra_variables)
  return weights + non_trainable_extra_variables


@deprecation.deprecated('2020-06-23',
                        'The Theano kernel format is legacy; '
                        'this utility will be removed.')
def convert_all_kernels_in_model(model):
  """Converts all convolution kernels in a model from Theano to TensorFlow.

  Also works from TensorFlow to Theano.

  This is used for converting legacy Theano-saved model files.

  Arguments:
      model: target model for the conversion.
  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {
      'Conv1D',
      'Conv2D',
      'Conv3D',
      'Conv2DTranspose',
  }
  to_assign = []
  for layer in model.layers:
    if layer.__class__.__name__ in conv_classes:
      original_kernel = K.get_value(layer.kernel)
      converted_kernel = convert_kernel(original_kernel)
      to_assign.append((layer.kernel, converted_kernel))
  K.batch_set_value(to_assign)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Arguments:
      dense: The target `Dense` layer.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer.
      target_data_format: One of "channels_last", "channels_first".
          Set it to "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  for i in range(kernel.shape[1]):
    if target_data_format == 'channels_first':
      c, h, w = previous_feature_map_shape
      original_fm_shape = (h, w, c)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (2, 0, 1))  # last -> first
    else:
      h, w, c = previous_feature_map_shape
      original_fm_shape = (c, h, w)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (1, 2, 0))  # first -> last
    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])
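
# Illustrative call (not in the original source): after porting a
# 'channels_first' model to 'channels_last', rewire the Dense layer that
# follows the Flatten, given the pre-Flatten feature map shape:
#
#   convert_dense_weights_data_format(
#       dense_layer, previous_feature_map_shape=(512, 7, 7),
#       target_data_format='channels_last')  # `dense_layer` is hypothetical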


def is_builtin_layer(layer):
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))
@ -1,524 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow-related utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import numpy as np
import six

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib


def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

  Arguments:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if isinstance(pred, variables.Variable):
    return control_flow_ops.cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name)
  return smart_module.smart_cond(
      pred, true_fn=true_fn, false_fn=false_fn, name=name)
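
# Illustrative behavior (not part of the original source): a Python bool is
# resolved immediately, while a symbolic predicate builds a `tf.cond`.
#
#   smart_cond(True, lambda: 1, lambda: 2)    # returns 1 without tf.cond
#   smart_cond(is_training_tensor,            # hypothetical boolean tensor
#              lambda: x * 0.5, lambda: x)    # routes through tf.cond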


def constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Arguments:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0.
  """
  # Allow integer booleans.
  if isinstance(pred, int):
    if pred == 1:
      pred = True
    elif pred == 0:
      pred = False

  if isinstance(pred, variables.Variable):
    return None
  return smart_module.smart_constant_value(pred)
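
# E.g. (illustrative): constant_value(1) -> True, constant_value(False) ->
# False, and constant_value(some_variable) -> None, since a Variable's value
# is only known at run time.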


def is_tensor_or_tensor_list(v):
  v = nest.flatten(v)
  if v and isinstance(v[0], ops.Tensor):
    return True
  else:
    return False


def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Stops if all targets have been found (target is optional).

  Only valid in Symbolic mode, not Eager mode.

  Args:
    inputs: List of tensors.
    targets: List of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs themselves).
  """
  inputs = nest.flatten(inputs, expand_composites=True)
  reachable = object_identity.ObjectIdentitySet(inputs)
  if targets:
    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
  queue = inputs[:]

  while queue:
    x = queue.pop()
    if isinstance(x, tuple(_user_convertible_tensor_types)):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      outputs = x.outputs[:] or []
      outputs += x._control_outputs  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tensor(x):
      outputs = x.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        if targets:
          remaining_targets.discard(y)
        queue.insert(0, y)

    if targets and not remaining_targets:
      return reachable

  return reachable


# This function needs access to private functions of `nest`.
# pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Maps the atomic elements of a nested structure.

  Arguments:
    is_atomic_fn: A function that determines if an element of `nested` is
      atomic.
    map_fn: The function to apply to atomic elements of `nested`.
    nested: A nested structure.

  Returns:
    The nested structure, with atomic elements mapped according to `map_fn`.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  # Recursively convert.
  if not nest.is_sequence(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))
  if nest._is_mapping(nested):
    values = [nested[k] for k in nest._sorted(nested)]
  else:
    values = nested
  mapped_values = [
      map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
  ]
  return nest._sequence_like(nested, mapped_values)


# pylint: enable=protected-access


def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to desired format.

  Performs:

  TensorShapes -> tuples if `to_tuples=True`.
  tuples of int or None -> TensorShapes if `to_tuples=False`.

  Valid objects to be converted are:
  - TensorShapes
  - tuples with elements of type int or None.
  - ints
  - None

  Arguments:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.

  Raises:
    ValueError: when the input tensor shape can't be converted to tuples, eg
      unknown tensor shape.
  """

  def _is_shape_component(value):
    return value is None or isinstance(value, (int, tensor_shape.Dimension))

  def _is_atomic_shape(input_shape):
    # Ex: TensorShape or (None, 10, 32) or 5 or `None`
    if _is_shape_component(input_shape):
      return True
    if isinstance(input_shape, tensor_shape.TensorShape):
      return True
    if (isinstance(input_shape, (tuple, list)) and
        all(_is_shape_component(ele) for ele in input_shape)):
      return True
    return False

  def _convert_shape(input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if to_tuples:
      input_shape = tuple(input_shape.as_list())
    return input_shape

  return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
                                   input_shape)
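
# Illustrative round trip (not in the original source):
#
#   convert_shapes((None, 10, 32), to_tuples=False)
#     # -> TensorShape([None, 10, 32])
#   convert_shapes(tensor_shape.TensorShape([None, 10, 32]), to_tuples=True)
#     # -> (None, 10, 32)
#
# Nested structures (lists or dicts of shapes) are converted leaf-wise.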


class ListWrapper(object):
  """A wrapper for lists to be treated as elements for `nest`."""

  def __init__(self, list_to_wrap):
    self._list = list_to_wrap

  def as_list(self):
    return self._list


def convert_inner_node_data(nested, wrap=False):
  """Either wraps or unwraps innermost node data lists in `ListWrapper` objects.

  Arguments:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
  """

  def _is_serialized_node_data(nested):
    # Node data can be of form `[layer_name, node_id, tensor_id]` or
    # `[layer_name, node_id, tensor_id, kwargs]`.
    if (isinstance(nested, list) and (len(nested) in [3, 4]) and
        isinstance(nested[0], six.string_types)):
      return True
    return False

  def _is_atomic_nested(nested):
    """Returns `True` if `nested` is a list representing node data."""
    if isinstance(nested, ListWrapper):
      return True
    if _is_serialized_node_data(nested):
      return True
    return not nest.is_sequence(nested)

  def _convert_object_or_list(nested):
    """Converts between `ListWrapper` object and list representations."""
    if wrap:
      if isinstance(nested, ListWrapper):
        return nested
      if _is_serialized_node_data(nested):
        return ListWrapper(nested)
      return nested
    else:
      if isinstance(nested, ListWrapper):
        return nested.as_list()
      return nested

  return map_structure_with_atomic(_is_atomic_nested, _convert_object_or_list,
                                   nested)


def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`.

  Arguments:
    fn: function to wrap.

  Returns:
    Wrapped function.
  """

  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`.
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper
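
# Typical use (illustrative; `MyLayer` and `units` are hypothetical):
# decorating `compute_output_shape` so the body can assume plain tuples
# regardless of whether the caller passed tuples or TensorShapes.
#
#   class MyLayer(Layer):
#
#     @shape_type_conversion
#     def compute_output_shape(self, input_shape):
#       return (input_shape[0], self.units)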


def are_all_symbolic_tensors(tensors):
  return all(is_symbolic_tensor(tensor) for tensor in tensors)


_user_convertible_tensor_types = set()


def is_symbolic_tensor(tensor):
  """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.

  A Variable can be seen as either: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.

  Arguments:
    tensor: A tensor instance to test.

  Returns:
    True for symbolic tensors, False for eager tensors.
  """
  if isinstance(tensor, tuple(_user_convertible_tensor_types)):
    tensor = ops.convert_to_tensor_or_composite(tensor)
  if isinstance(tensor, variables.Variable):
    # Variables that are output of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    return (getattr(tensor, '_keras_history', False) or
            not context.executing_eagerly())
  if isinstance(tensor, composite_tensor.CompositeTensor):
    component_tensors = nest.flatten(tensor, expand_composites=True)
    return any(hasattr(t, 'graph') for t in component_tensors)
  if isinstance(tensor, ops.Tensor):
    return hasattr(tensor, 'graph')
  return False


def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.utils.register_symbolic_tensor_type(cls)` allows non-`Tensor`
  objects to be plumbed through Keras layers.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Arguments:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.
  """
  global _user_convertible_tensor_types
  _user_convertible_tensor_types.add(cls)


def type_spec_from_value(value):
  """Grab type_spec without converting array-likes to tensors."""
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  # Get a TensorSpec for array-like data without
  # converting the data to a Tensor.
  if hasattr(value, 'shape') and hasattr(value, 'dtype'):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  else:
    return type_spec.type_spec_from_value(value)


def is_tensor_or_variable(x):
  return tensor_util.is_tensor(x) or isinstance(x, variables.Variable)


def assert_no_legacy_layers(layers):
  """Prevent tf.layers.Layers from being used with Keras.

  Certain legacy layers inherit from their keras analogs; however they are
  not supported with keras and can lead to subtle and hard to diagnose bugs.

  Args:
    layers: A list of layers to check

  Raises:
    TypeError: If any elements of layers are tf.layers.Layers
  """

  # isinstance check for tf.layers.Layer introduces a circular dependency.
  legacy_layers = [l for l in layers if getattr(l, '_is_legacy_layer', None)]
  if legacy_layers:
    layer_str = '\n'.join('  ' + str(l) for l in legacy_layers)
    raise TypeError(
        'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
        'framework (for instance using the Network, Model, or Sequential '
        'classes), please use the tf.keras.layers implementation instead. '
        '(Or, if writing custom layers, subclass from tf.keras.layers rather '
        'than tf.layers)'.format(layer_str))


@tf_contextlib.contextmanager
def maybe_init_scope(layer):
  """Open an `init_scope` if in V2 mode and using the keras graph.

  Arguments:
    layer: The Layer/Model that is currently active.

  Yields:
    None
  """
  # Don't open an init_scope in V1 mode or when using legacy tf.layers.
  if (ops.executing_eagerly_outside_functions() and
      getattr(layer, '_keras_style', True)):
    with ops.init_scope():
      yield
  else:
    yield


@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
  """Returns graph context manager if any of the inputs is a symbolic tensor."""
  if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())):
    with K.get_graph().as_default():
      yield
  else:
    yield


def dataset_is_infinite(dataset):
  """True if the passed dataset is infinite."""
  if ops.executing_eagerly_outside_functions():
    return math_ops.equal(
        cardinality.cardinality(dataset), cardinality.INFINITE)
  else:
    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
    return dataset_size == cardinality.INFINITE


def get_tensor_spec(t, dynamic_batch=False, name=None):
  """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
  if isinstance(t, type_spec.TypeSpec):
    spec = t
  elif isinstance(t, composite_tensor.CompositeTensor):
    # TODO(b/148821952): Should these specs have a name attr?
    spec = t._type_spec  # pylint: disable=protected-access
  elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
    spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
  else:
    return None  # Allow non-Tensors to pass through.

  if not dynamic_batch:
    return spec

  dynamic_batch_spec = copy.deepcopy(spec)
  # RaggedTensorSpec only has a private _shape.
  shape = dynamic_batch_spec._shape.as_list()  # pylint: disable=protected-access
  if shape:
    shape[0] = None
  dynamic_batch_spec._shape = tensor_shape.TensorShape(shape)  # pylint: disable=protected-access
  return dynamic_batch_spec
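
# E.g. (illustrative): for a float32 tensor `t` of shape (32, 224, 224, 3),
# get_tensor_spec(t) keeps the static batch size, while
# get_tensor_spec(t, dynamic_batch=True) yields a (None, 224, 224, 3) spec,
# so one signature can accept any batch size.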


def to_numpy_or_python_type(tensors):
  """Converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.

  For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
  it converts it to a Python type, such as a float or int, by calling
  `result.item()`.

  Numpy scalars are converted, as Python types are often more convenient to deal
  with. This is especially useful for bfloat16 Numpy scalars, which don't
  support as many operations as other Numpy values.

  Args:
    tensors: A structure of tensors.

  Returns:
    `tensors`, but scalar tensors are converted to Python types and non-scalar
    tensors are converted to Numpy arrays.
  """
  def _to_single_numpy_or_python_type(t):
    if isinstance(t, ops.Tensor):
      x = t.numpy()
      return x.item() if np.ndim(x) == 0 else x
    return t  # Don't turn ragged or sparse tensors to NumPy.

  return nest.map_structure(_to_single_numpy_or_python_type, tensors)
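
# Illustrative behavior (not in the original source): a scalar float32 tensor
# becomes a Python float via `.item()`, a rank-1 tensor becomes a NumPy
# ndarray, and ragged or sparse values pass through unchanged.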

@ -1,162 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras TF utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.frozen_keras.utils import tf_utils
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class TestIsSymbolicTensor(test.TestCase):

  def test_default_behavior(self):
    if context.executing_eagerly():
      self.assertFalse(tf_utils.is_symbolic_tensor(
          variables.Variable(name='blah', initial_value=0.)))
      self.assertFalse(
          tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
      self.assertFalse(tf_utils.is_symbolic_tensor(
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
    else:
      self.assertTrue(tf_utils.is_symbolic_tensor(
          variables.Variable(name='blah', initial_value=0.)))
      self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
      self.assertTrue(tf_utils.is_symbolic_tensor(
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))

  def test_works_with_registered(self):

    class CustomClass(object):

      def value(self):
        return ops.convert_to_tensor_v2(42.)

    ops.register_tensor_conversion_function(
        CustomClass, lambda value, **_: value.value())

    tf_utils.register_symbolic_tensor_type(CustomClass)

    if context.executing_eagerly():
      self.assertFalse(tf_utils.is_symbolic_tensor(
          variables.Variable(name='blah', initial_value=0.)))
      self.assertFalse(
          tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
      self.assertFalse(tf_utils.is_symbolic_tensor(
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
      self.assertFalse(tf_utils.is_symbolic_tensor(CustomClass()))
    else:
      self.assertTrue(tf_utils.is_symbolic_tensor(
          variables.Variable(name='blah', initial_value=0.)))
      self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
      self.assertTrue(tf_utils.is_symbolic_tensor(
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
      self.assertTrue(tf_utils.is_symbolic_tensor(CustomClass()))

  def test_enables_nontensor_plumbing(self):
    self.skipTest('Sequential model will check layer instance type and fail.')

    if context.executing_eagerly():
      self.skipTest('`compile` functionality changed.')
    # Setup.

    class Foo(object):

      def __init__(self, input_):
        self._input = input_
        self.value = ops.convert_to_tensor_v2([[42.]])

      @property
      def dtype(self):
        return self.value.dtype

    ops.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value)
    tf_utils.register_symbolic_tensor_type(Foo)

    class PlumbingLayer(keras.layers.Lambda):

      def __init__(self, fn, **kwargs):
        def _fn(*fargs, **fkwargs):
          d = fn(*fargs, **fkwargs)
          x = ops.convert_to_tensor_v2(d)
          d.shape = x.shape
          d.get_shape = x.get_shape
          return d, x
        super(PlumbingLayer, self).__init__(_fn, **kwargs)
        self._enter_dunder_call = False

      def __call__(self, inputs, *args, **kwargs):
        self._enter_dunder_call = True
        d, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
        self._enter_dunder_call = False
        return d

      def call(self, inputs, *args, **kwargs):
        d, v = super(PlumbingLayer, self).call(inputs, *args, **kwargs)
        if self._enter_dunder_call:
          return d, v
        return d

    # User-land.
    model = keras.Sequential([
        keras.layers.InputLayer((1,)),
        PlumbingLayer(Foo),  # Makes a `Foo` object.
    ])
    # Let's ensure Keras graph history is preserved by composing the models.
    model = keras.Model(model.inputs, model(model.outputs))
    # Now we instantiate the model and verify we have a `Foo` object, not a
    # `Tensor`.
    y = model(ops.convert_to_tensor_v2([[7.]]))
    self.assertIsInstance(y, Foo)
    # Confirm that (custom) loss sees `Foo` instance, not Tensor.
    obtained_prediction_box = [None]
    def custom_loss(y_obs, y_pred):
      del y_obs
      obtained_prediction_box[0] = y_pred
      return y_pred
    # Apparently `compile` calls the loss function enough to trigger the
    # side-effect.
    model.compile('SGD', loss=custom_loss)
    self.assertIsInstance(obtained_prediction_box[0], Foo)


class ConvertInnerNodeDataTest(test.TestCase):

  def test_convert_inner_node_data(self):
    data = tf_utils.convert_inner_node_data((tf_utils.ListWrapper(['l', 2, 3]),
                                             tf_utils.ListWrapper(['l', 5, 6])))
    self.assertEqual(data, (['l', 2, 3], ['l', 5, 6]))

    data = tf_utils.convert_inner_node_data(((['l', 2, 3], ['l', 5, 6])),
                                            wrap=True)
    self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))


if __name__ == '__main__':
  test.main()

@ -44,7 +44,6 @@ py_library(
         "//tensorflow/python/distribute:input_lib",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/eager:monitoring",
-        "//tensorflow/python/frozen_keras/engine:legacy_base_layer",
         "//tensorflow/python/keras:activations",
         "//tensorflow/python/keras:backend",
         "//tensorflow/python/keras:callbacks",

@ -21,7 +21,6 @@ from __future__ import print_function

 import copy

-from tensorflow.python.frozen_keras.engine import legacy_base_layer
 from tensorflow.python.keras import layers as layer_module
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import input_layer

@ -169,8 +168,7 @@ class Sequential(training.Model):
     if isinstance(origin_layer, input_layer.InputLayer):
       layer = origin_layer

-    if not isinstance(layer,
-                      (base_layer.Layer, legacy_base_layer.LegacyBaseLayer)):
+    if not isinstance(layer, base_layer.Layer):
       raise TypeError('The added layer must be '
                       'an instance of class Layer. '
                       'Found: ' + str(layer))

@ -122,7 +122,6 @@ COMMON_PIP_DEPS = [
     "//tensorflow/python/distribute:combinations",
     "//tensorflow/python/distribute:multi_process_runner",
     "//tensorflow/python/eager:eager_pip",
-    "//tensorflow/python/frozen_keras/engine:legacy_base_layer",
     "//tensorflow/python/keras:combinations",
     "//tensorflow/python/keras/layers/preprocessing:preprocessing_test_utils",
     "//tensorflow/python/keras/distribute:distribute_strategy_test_lib",