Remove the frozen copy of Keras code.

PiperOrigin-RevId: 300795615
Change-Id: Ibe8e69cddeb992aaa00a87da4c6543c8804f7b14
Authored by Scott Zhu on 2020-03-13 11:37:29 -07:00; committed by TensorFlower Gardener.
parent 4fcd935d48
commit 9d297ebabc
29 changed files with 1 addition and 18363 deletions

tensorflow/python/frozen_keras/BUILD

@@ -1,175 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
# TODO(scottzhu): Clean up all the deps to python/keras
py_library(
name = "frozen_keras",
deps = [
":backend",
":backend_config",
":constraint",
":initializers",
":regularizers",
"//tensorflow/python/frozen_keras/engine:legacy_base_layer",
],
)
py_library(
name = "activations",
srcs = ["activations.py"],
deps = [
":backend",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python/frozen_keras/utils:generic_utils",
"@six_archive//:six",
],
)
py_library(
name = "backend",
srcs = ["backend.py"],
deps = [
":backend_config",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:config",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:ctc_ops",
"//tensorflow/python:device",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:func_graph",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:image_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:logging_ops",
"//tensorflow/python:map_fn",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:platform",
"//tensorflow/python:random_ops",
"//tensorflow/python:session",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:state_ops",
"//tensorflow/python:tensor_array_grad",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:tf2",
"//tensorflow/python:training_lib",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:distribute_coordinator",
"//tensorflow/python/distribute:distribute_coordinator_context",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:lift_to_graph",
"//tensorflow/python/ops/ragged:ragged_concat_ops",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//third_party/py/numpy",
],
)
py_library(
name = "backend_config",
srcs = ["backend_config.py"],
deps = [],
)
py_library(
name = "constraint",
srcs = ["constraints.py"],
deps = [
":backend",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/frozen_keras/utils:generic_utils",
"@six_archive//:six",
],
)
py_library(
name = "initializers",
srcs = ["initializers.py"],
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:init_ops",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python:tf2",
"//tensorflow/python/frozen_keras/utils:generic_utils",
"@six_archive//:six",
],
)
py_library(
name = "regularizers",
srcs = ["regularizers.py"],
deps = [
":backend",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python/frozen_keras/utils:generic_utils",
"@six_archive//:six",
],
)
tf_py_test(
name = "backend_test",
size = "medium",
srcs = ["backend_test.py"],
python_version = "PY3",
shard_count = 4,
deps = [
":backend",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:config",
"//tensorflow/python:errors",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:nn",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/frozen_keras/engine:base_layer_utils",
"//tensorflow/python/keras:combinations",
"//tensorflow/python/keras/engine",
"//tensorflow/python/keras/layers:advanced_activations",
"//tensorflow/python/keras/layers:normalization",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "backend_config_test",
size = "medium",
srcs = ["backend_config_test.py"],
python_version = "PY3",
deps = [
":backend",
":backend_config",
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras:combinations",
],
)

tensorflow/python/frozen_keras/README.md

@@ -1,15 +0,0 @@
# DO NOT USE
Everything under this package is for internal use only; it exists solely to
serve dependencies from legacy TF v1 APIs that rely on Keras. Any active
development should happen in third_party/tensorflow/python/keras instead.
## Background
In order to build a more modular TensorFlow and Keras, we decided to split the
Keras code into its own repository. Having TensorFlow depend on Keras is a red
flag, as it is a reverse dependency. Because some legacy TF V1 APIs use Keras
classes as base classes (e.g. `Layer`), we decided to keep a trimmed copy of
the Keras code to resolve the reverse dependency. This also ensures that the
stability of the TF V1 API is not affected by the active development of the
Keras project.
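For illustration, a legacy V1 consumer of this package subclasses the frozen
base layer rather than the live Keras one. A minimal sketch (the
`LegacyBaseLayer` class name and the trivial `call` override are assumptions
for illustration, not part of this README):
```
# Hypothetical legacy V1 layer that depends on the frozen copy instead of
# the actively developed Keras code.
from tensorflow.python.frozen_keras.engine import legacy_base_layer

class _MyLegacyOp(legacy_base_layer.LegacyBaseLayer):  # class name assumed
  def call(self, inputs):
    return inputs
```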

tensorflow/python/frozen_keras/activations.py

@@ -1,453 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in activation functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
# b/123041942
# In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras
# layers, it gets serialized as 'softmax_v2' instead of 'softmax' as the
# internal method name is returned in serialization. This results in errors in
# model exporting and loading as Keras can't find any activation function with
# the name of `softmax_v2`.
# This dict maps the activation function name from its v2 version to its
# canonical name.
_TF_ACTIVATIONS_V2 = {
'softmax_v2': 'softmax',
}
def softmax(x, axis=-1):
"""Softmax converts a real vector to a vector of categorical probabilities.
The elements of the output vector are in range (0, 1) and sum to 1.
Each vector is handled independently. The `axis` argument sets which axis
of the input the function is applied along.
Softmax is often used as the activation for the last
layer of a classification network because the result could be interpreted as
a probability distribution.
The softmax of each vector x is calculated by `exp(x)/tf.reduce_sum(exp(x))`.
The input values are the log-odds of the resulting probability.
Arguments:
x : Input tensor.
axis: Integer, axis along which the softmax normalization is applied.
Returns:
Tensor, output of softmax transformation (all values are non-negative
and sum to 1).
Raises:
ValueError: In case `dim(x) == 1`.
"""
ndim = K.ndim(x)
if ndim == 2:
return nn.softmax(x)
elif ndim > 2:
e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
return e / s
else:
raise ValueError('Cannot apply softmax to a tensor that is 1D. '
'Received input: %s' % (x,))
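# A quick numeric sketch of the formula above (comment-only; assumes eager
# execution, and `constant_op` is an illustrative import that this module
# itself does not need):
#
#   from tensorflow.python.framework import constant_op
#   logits = constant_op.constant([[1., 2., 3.], [1., 1., 1.]])
#   probs = softmax(logits)
#   # Each row of `probs` is positive and sums to 1; the second row is
#   # approximately [0.333, 0.333, 0.333].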
def elu(x, alpha=1.0):
"""Exponential linear unit.
Arguments:
x: Input tensor.
alpha: A scalar, slope of negative section.
Returns:
The exponential linear activation: `x` if `x > 0` and
`alpha * (exp(x)-1)` if `x < 0`.
Reference:
- [Fast and Accurate Deep Network Learning by Exponential
Linear Units (ELUs)](https://arxiv.org/abs/1511.07289)
"""
return K.elu(x, alpha)
def selu(x):
"""Scaled Exponential Linear Unit (SELU).
The Scaled Exponential Linear Unit (SELU) activation function is:
`scale * x` if `x > 0` and `scale * alpha * (exp(x) - 1)` if `x < 0`
where `alpha` and `scale` are pre-defined constants
(`alpha = 1.67326324`
and `scale = 1.05070098`).
The SELU activation function multiplies the output of
[elu](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/activations/elu)
(Exponential Linear Unit) by `scale` > 1 to ensure a slope larger than one
for positive net inputs.
The values of `alpha` and `scale` are
chosen so that the mean and variance of the inputs are preserved
between two consecutive layers as long as the weights are initialized
correctly (see [`lecun_normal` initialization]
(https://www.tensorflow.org/api_docs/python/tf/keras/initializers/lecun_normal))
and the number of inputs is "large enough"
(see references for more information).
![](https://cdn-images-1.medium.com/max/1600/1*m0e8lZU_Zrkh4ESfQkY2Pw.png)
(Courtesy: blog post on Towards Data Science at
https://towardsdatascience.com/selu-make-fnns-great-again-snn-8d61526802a9)
Example Usage:
>>> n_classes = 10 #10-class problem
>>> from tensorflow.python.keras.layers import Dense
>>> model = tf.keras.Sequential()
>>> model.add(Dense(64, kernel_initializer='lecun_normal',
... activation='selu', input_shape=(28, 28, 1)))
>>> model.add(Dense(32, kernel_initializer='lecun_normal',
... activation='selu'))
>>> model.add(Dense(16, kernel_initializer='lecun_normal',
... activation='selu'))
>>> model.add(Dense(n_classes, activation='softmax'))
Arguments:
x: A tensor or variable to compute the activation function for.
Returns:
The scaled exponential unit activation: `scale * elu(x, alpha)`.
# Note
- To be used together with the initialization "[lecun_normal]
(https://www.tensorflow.org/api_docs/python/tf/keras/initializers/lecun_normal)".
- To be used together with the dropout variant "[AlphaDropout]
(https://www.tensorflow.org/api_docs/python/tf/keras/layers/AlphaDropout)".
References:
[Self-Normalizing Neural Networks (Klambauer et al, 2017)]
(https://arxiv.org/abs/1706.02515)
"""
return nn.selu(x)
def softplus(x):
"""Softplus activation function.
Arguments:
x: Input tensor.
Returns:
The softplus activation: `log(exp(x) + 1)`.
"""
return nn.softplus(x)
def softsign(x):
"""Softsign activation function.
Arguments:
x: Input tensor.
Returns:
The softsign activation: `x / (abs(x) + 1)`.
"""
return nn.softsign(x)
def swish(x):
"""Swish activation function.
Arguments:
x: Input tensor.
Returns:
The swish activation applied to `x`.
"""
return nn.swish(x)
def relu(x, alpha=0., max_value=None, threshold=0):
"""Applies the rectified linear unit activation function.
With default values, this returns the standard ReLU activation:
`max(x, 0)`, the element-wise maximum of 0 and the input tensor.
Modifying default parameters allows you to use non-zero thresholds,
change the max value of the activation,
and to use a non-zero multiple of the input for values below the threshold.
For example:
>>> foo = tf.constant([-10, -5, 0.0, 5, 10], dtype = tf.float32)
>>> tf.keras.activations.relu(foo).numpy()
array([ 0., 0., 0., 5., 10.], dtype=float32)
>>> tf.keras.activations.relu(foo, alpha=0.5).numpy()
array([-5. , -2.5, 0. , 5. , 10. ], dtype=float32)
>>> tf.keras.activations.relu(foo, max_value=5).numpy()
array([0., 0., 0., 5., 5.], dtype=float32)
>>> tf.keras.activations.relu(foo, threshold=5).numpy()
array([-0., -0., 0., 0., 10.], dtype=float32)
Arguments:
x: Input `tensor` or `variable`.
alpha: A `float` that governs the slope for values lower than the
threshold.
max_value: A `float` that sets the saturation threshold (the largest value
the function will return).
threshold: A `float` giving the threshold value of the activation function
below which values will be damped or set to zero.
Returns:
A `Tensor` representing the input tensor,
transformed by the relu activation function.
Tensor will be of the same shape and dtype of input `x`.
"""
return K.relu(x, alpha=alpha, max_value=max_value, threshold=threshold)
def tanh(x):
"""Hyperbolic tangent activation function.
For example:
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.tanh(a)
>>> b.numpy()
array([-0.9950547, -0.7615942, 0. , 0.7615942, 0.9950547],
dtype=float32)
Arguments:
x: Input tensor.
Returns:
Tensor of same shape and dtype of input `x`, with tanh activation:
`tanh(x) = sinh(x)/cosh(x) = ((exp(x) - exp(-x))/(exp(x) + exp(-x)))`.
"""
return nn.tanh(x)
def sigmoid(x):
"""Sigmoid activation function.
Applies the sigmoid activation function. The sigmoid function is defined as
1 divided by (1 + exp(-x)). Its curve is S-shaped, like a smoothed version
of the Heaviside (unit step) function. For small values (< -5) the sigmoid
returns a value close to zero, and for large values (> 5) the result of the
function gets close to 1.
Sigmoid is equivalent to a 2-element Softmax, where the second element is
assumed to be zero.
For example:
>>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype = tf.float32)
>>> b = tf.keras.activations.sigmoid(a)
>>> b.numpy() >= 0.0
array([ True, True, True, True, True])
Arguments:
x: Input tensor.
Returns:
Tensor with the sigmoid activation: `(1.0 / (1.0 + exp(-x)))`.
Tensor will be of same shape and dtype of input `x`.
"""
return nn.sigmoid(x)
def exponential(x):
"""Exponential activation function.
For example:
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.exponential(a)
>>> b.numpy()
array([ 0.04978707, 0.36787945, 1. , 2.7182817 , 20.085537 ],
dtype=float32)
Arguments:
x: Input tensor.
Returns:
Tensor with exponential activation: `exp(x)`. Tensor will be of same
shape and dtype of input `x`.
"""
return math_ops.exp(x)
def hard_sigmoid(x):
"""Hard sigmoid activation function.
Faster to compute than sigmoid activation.
For example:
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.hard_sigmoid(a)
>>> b.numpy()
array([0. , 0.3, 0.5, 0.7, 1. ], dtype=float32)
Arguments:
x: Input tensor.
Returns:
The hard sigmoid activation:
- `0` if `x < -2.5`
- `1` if `x > 2.5`
- `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`.
"""
return K.hard_sigmoid(x)
def linear(x):
"""Linear activation function.
For example:
>>> a = tf.constant([-3.0,-1.0, 0.0,1.0,3.0], dtype = tf.float32)
>>> b = tf.keras.activations.linear(a)
>>> b.numpy()
array([-3., -1., 0., 1., 3.], dtype=float32)
Arguments:
x: Input tensor.
Returns:
the input unmodified.
"""
return x
def serialize(activation):
"""Returns name attribute (`__name__`) of function.
Arguments:
activation: Function.
Returns:
String denoting the name attribute of the input function
For example:
>>> tf.keras.activations.serialize(tf.keras.activations.tanh)
'tanh'
>>> tf.keras.activations.serialize(tf.keras.activations.sigmoid)
'sigmoid'
>>> tf.keras.activations.serialize('abcd')
Traceback (most recent call last):
...
ValueError: ('Cannot serialize', 'abcd')
Raises:
ValueError: The input function is not a valid one.
"""
if (hasattr(activation, '__name__') and
activation.__name__ in _TF_ACTIVATIONS_V2):
return _TF_ACTIVATIONS_V2[activation.__name__]
return serialize_keras_object(activation)
def deserialize(name, custom_objects=None):
"""Returns activation function denoted by input string.
Arguments:
x : String
Returns:
TensorFlow Activation function denoted by input string.
For example:
>>> tf.keras.activations.deserialize('linear')
<function linear at 0x1239596a8>
>>> tf.keras.activations.deserialize('sigmoid')
<function sigmoid at 0x123959510>
>>> tf.keras.activations.deserialize('abcd')
Traceback (most recent call last):
...
ValueError: Unknown activation function:abcd
Args:
name: The name of the activation function.
custom_objects: A {name:value} dictionary for activations not build into
keras.
Raises:
ValueError: `Unknown activation function` if the input string does not
denote any defined Tensorflow activation function.
"""
return deserialize_keras_object(
name,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='activation function')
def get(identifier):
"""Returns function.
Arguments:
identifier: Function or string
Returns:
Activation function denoted by input:
- `Linear activation function` if input is `None`.
- Function corresponding to the input string or input function.
For example:
>>> tf.keras.activations.get('softmax')
<function softmax at 0x1222a3d90>
>>> tf.keras.activations.get(tf.keras.activations.softmax)
<function softmax at 0x1222a3d90>
>>> tf.keras.activations.get(None)
<function linear at 0x1239596a8>
>>> tf.keras.activations.get(abs)
<built-in function abs>
>>> tf.keras.activations.get('abcd')
Traceback (most recent call last):
...
ValueError: Unknown activation function:abcd
Raises:
ValueError: Input is an unknown function or string, i.e., the input does
not denote any defined function.
"""
if identifier is None:
return linear
if isinstance(identifier, six.string_types):
identifier = str(identifier)
return deserialize(identifier)
elif callable(identifier):
return identifier
elif isinstance(identifier, dict):
return deserialize_keras_object(
identifier, printable_module_name='activation')
else:
raise TypeError(
'Could not interpret activation function identifier: {}'.format(
repr(identifier)))
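# A hedged round-trip sketch tying `get`, `serialize` and `deserialize`
# together (comment-only):
#
#   fn = get('tanh')          # string -> this module's `tanh` function
#   name = serialize(fn)      # -> 'tanh'
#   same_fn = deserialize(name)
#   # Functions whose `__name__` is 'softmax_v2' serialize as 'softmax'
#   # via `_TF_ACTIVATIONS_V2` (see the comment near the top of the file).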

File diff suppressed because it is too large.

tensorflow/python/frozen_keras/backend_config.py

@@ -1,140 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras backend config API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# The type of float to use throughout a session.
_FLOATX = 'float32'
# Epsilon fuzz factor used throughout the codebase.
_EPSILON = 1e-7
# Default image data format, one of "channels_last", "channels_first".
_IMAGE_DATA_FORMAT = 'channels_last'
def epsilon():
"""Returns the value of the fuzz factor used in numeric expressions.
Returns:
A float.
Example:
>>> tf.keras.backend.epsilon()
1e-07
"""
return _EPSILON
def set_epsilon(value):
"""Sets the value of the fuzz factor used in numeric expressions.
Arguments:
value: float. New value of epsilon.
Example:
>>> tf.keras.backend.epsilon()
1e-07
>>> tf.keras.backend.set_epsilon(1e-5)
>>> tf.keras.backend.epsilon()
1e-05
>>> tf.keras.backend.set_epsilon(1e-7)
"""
global _EPSILON
_EPSILON = value
def floatx():
"""Returns the default float type, as a string.
E.g. `'float16'`, `'float32'`, `'float64'`.
Returns:
String, the current default float type.
Example:
>>> tf.keras.backend.floatx()
'float32'
"""
return _FLOATX
def set_floatx(value):
"""Sets the default float type.
Note: It is not recommended to set this to float16 for training, as this will
likely cause numeric stability issues. Instead, mixed precision, which is
using a mix of float16 and float32, can be used by calling
`tf.keras.mixed_precision.experimental.set_policy('mixed_float16')`. See the
[mixed precision
guide](https://www.tensorflow.org/guide/keras/mixed_precision) for details.
Arguments:
value: String; `'float16'`, `'float32'`, or `'float64'`.
Example:
>>> tf.keras.backend.floatx()
'float32'
>>> tf.keras.backend.set_floatx('float64')
>>> tf.keras.backend.floatx()
'float64'
>>> tf.keras.backend.set_floatx('float32')
Raises:
ValueError: In case of invalid value.
"""
global _FLOATX
if value not in {'float16', 'float32', 'float64'}:
raise ValueError('Unknown floatx type: ' + str(value))
_FLOATX = str(value)
def image_data_format():
"""Returns the default image data format convention.
Returns:
A string, either `'channels_first'` or `'channels_last'`
Example:
>>> tf.keras.backend.image_data_format()
'channels_last'
"""
return _IMAGE_DATA_FORMAT
def set_image_data_format(data_format):
"""Sets the value of the image data format convention.
Arguments:
data_format: string. `'channels_first'` or `'channels_last'`.
Example:
>>> tf.keras.backend.image_data_format()
'channels_last'
>>> tf.keras.backend.set_image_data_format('channels_first')
>>> tf.keras.backend.image_data_format()
'channels_first'
>>> tf.keras.backend.set_image_data_format('channels_last')
Raises:
ValueError: In case of invalid `data_format` value.
"""
global _IMAGE_DATA_FORMAT
if data_format not in {'channels_last', 'channels_first'}:
raise ValueError('Unknown data_format: ' + str(data_format))
_IMAGE_DATA_FORMAT = str(data_format)
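# A small usage sketch mirroring the doctests above (comment-only):
# temporarily overriding the default float type and restoring it.
#
#   default = floatx()        # 'float32' unless changed earlier
#   set_floatx('float64')
#   try:
#     pass                    # build objects that should default to float64
#   finally:
#     set_floatx(default)     # always restore the previous default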

tensorflow/python/frozen_keras/backend_config_test.py

@@ -1,55 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for backend_config."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.frozen_keras import backend
from tensorflow.python.frozen_keras import backend_config
from tensorflow.python.keras import combinations
from tensorflow.python.platform import test
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendConfigTest(test.TestCase):
def test_backend(self):
self.assertEqual(backend.backend(), 'tensorflow')
def test_epsilon(self):
epsilon = 1e-2
backend_config.set_epsilon(epsilon)
self.assertEqual(backend_config.epsilon(), epsilon)
backend_config.set_epsilon(1e-7)
self.assertEqual(backend_config.epsilon(), 1e-7)
def test_floatx(self):
floatx = 'float64'
backend_config.set_floatx(floatx)
self.assertEqual(backend_config.floatx(), floatx)
backend_config.set_floatx('float32')
self.assertEqual(backend_config.floatx(), 'float32')
def test_image_data_format(self):
image_data_format = 'channels_first'
backend_config.set_image_data_format(image_data_format)
self.assertEqual(backend_config.image_data_format(), image_data_format)
backend_config.set_image_data_format('channels_last')
self.assertEqual(backend_config.image_data_format(), 'channels_last')
if __name__ == '__main__':
test.main()

File diff suppressed because it is too large.

tensorflow/python/frozen_keras/constraints.py

@@ -1,282 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-classes-have-attributes
"""Constraints: functions that impose constraints on weight values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.framework import tensor_shape
from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
class Constraint(object):
def __call__(self, w):
return w
def get_config(self):
return {}
class MaxNorm(Constraint):
"""MaxNorm weight constraint.
Constrains the weights incident to each hidden unit
to have a norm less than or equal to a desired value.
Arguments:
max_value: the maximum norm for the incoming weights.
axis: integer, axis along which to calculate weight norms.
For instance, in a `Dense` layer the weight matrix
has shape `(input_dim, output_dim)`,
set `axis` to `0` to constrain each weight vector
of length `(input_dim,)`.
In a `Conv2D` layer with `data_format="channels_last"`,
the weight tensor has shape
`(rows, cols, input_depth, output_depth)`,
set `axis` to `[0, 1, 2]`
to constrain the weights of each filter tensor of size
`(rows, cols, input_depth)`.
"""
def __init__(self, max_value=2, axis=0):
self.max_value = max_value
self.axis = axis
def __call__(self, w):
norms = K.sqrt(
math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
desired = K.clip(norms, 0, self.max_value)
return w * (desired / (K.epsilon() + norms))
def get_config(self):
return {'max_value': self.max_value, 'axis': self.axis}
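# A numeric sketch of the clipping formula above (comment-only; assumes
# eager execution):
#
#   w = K.constant([[3.0], [4.0]])  # single column with L2 norm 5
#   constrained = MaxNorm(max_value=2, axis=0)(w)
#   # norms = 5, desired = clip(5, 0, 2) = 2, so the column is rescaled
#   # to roughly [[1.2], [2.4]] (norm ~ 2, up to the K.epsilon() fuzz).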
class NonNeg(Constraint):
"""Constrains the weights to be non-negative.
"""
def __call__(self, w):
return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())
class UnitNorm(Constraint):
"""Constrains the weights incident to each hidden unit to have unit norm.
Arguments:
axis: integer, axis along which to calculate weight norms.
For instance, in a `Dense` layer the weight matrix
has shape `(input_dim, output_dim)`,
set `axis` to `0` to constrain each weight vector
of length `(input_dim,)`.
In a `Conv2D` layer with `data_format="channels_last"`,
the weight tensor has shape
`(rows, cols, input_depth, output_depth)`,
set `axis` to `[0, 1, 2]`
to constrain the weights of each filter tensor of size
`(rows, cols, input_depth)`.
"""
def __init__(self, axis=0):
self.axis = axis
def __call__(self, w):
return w / (
K.epsilon() + K.sqrt(
math_ops.reduce_sum(
math_ops.square(w), axis=self.axis, keepdims=True)))
def get_config(self):
return {'axis': self.axis}
class MinMaxNorm(Constraint):
"""MinMaxNorm weight constraint.
Constrains the weights incident to each hidden unit
to have the norm between a lower bound and an upper bound.
Arguments:
min_value: the minimum norm for the incoming weights.
max_value: the maximum norm for the incoming weights.
rate: rate for enforcing the constraint: weights will be
rescaled to yield
`(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
Effectively, this means that rate=1.0 stands for strict
enforcement of the constraint, while rate<1.0 means that
weights will be rescaled at each step to slowly move
towards a value inside the desired interval.
axis: integer, axis along which to calculate weight norms.
For instance, in a `Dense` layer the weight matrix
has shape `(input_dim, output_dim)`,
set `axis` to `0` to constrain each weight vector
of length `(input_dim,)`.
In a `Conv2D` layer with `data_format="channels_last"`,
the weight tensor has shape
`(rows, cols, input_depth, output_depth)`,
set `axis` to `[0, 1, 2]`
to constrain the weights of each filter tensor of size
`(rows, cols, input_depth)`.
"""
def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
self.min_value = min_value
self.max_value = max_value
self.rate = rate
self.axis = axis
def __call__(self, w):
norms = K.sqrt(
math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
desired = (
self.rate * K.clip(norms, self.min_value, self.max_value) +
(1 - self.rate) * norms)
return w * (desired / (K.epsilon() + norms))
def get_config(self):
return {
'min_value': self.min_value,
'max_value': self.max_value,
'rate': self.rate,
'axis': self.axis
}
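# A numeric sketch of the `rate` interpolation above (comment-only):
#
#   w = K.constant([[3.0], [4.0]])  # single column with L2 norm 5
#   c = MinMaxNorm(min_value=0.0, max_value=2.0, rate=0.5, axis=0)
#   constrained = c(w)
#   # desired = 0.5 * clip(5, 0, 2) + 0.5 * 5 = 3.5, so the column is
#   # rescaled by ~3.5 / 5 = 0.7, i.e. roughly [[2.1], [2.8]].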
class RadialConstraint(Constraint):
"""Constrains `Conv2D` kernel weights to be the same for each radius.
For example, the desired output for the following 4-by-4 kernel::
```
kernel = [[v_00, v_01, v_02, v_03],
[v_10, v_11, v_12, v_13],
[v_20, v_21, v_22, v_23],
[v_30, v_31, v_32, v_33]]
```
is this::
```
kernel = [[v_11, v_11, v_11, v_11],
[v_11, v_33, v_33, v_11],
[v_11, v_33, v_33, v_11],
[v_11, v_11, v_11, v_11]]
```
This constraint can be applied to any `Conv2D` layer version, including
`Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
`"channels_first"` data format. The method assumes the weight tensor is of
shape `(rows, cols, input_depth, output_depth)`.
"""
def __call__(self, w):
w_shape = w.shape
if w_shape.rank is None or w_shape.rank != 4:
raise ValueError(
'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)
height, width, channels, kernels = w_shape
w = K.reshape(w, (height, width, channels * kernels))
# TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
# is supported.
w = K.map_fn(
self._kernel_constraint,
K.stack(array_ops.unstack(w, axis=-1), axis=0))
return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
(height, width, channels, kernels))
def _kernel_constraint(self, kernel):
"""Radially constraints a kernel with shape (height, width, channels)."""
padding = K.constant([[1, 1], [1, 1]], dtype='int32')
kernel_shape = K.shape(kernel)[0]
start = K.cast(kernel_shape / 2, 'int32')
kernel_new = K.switch(
K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
lambda: kernel[start - 1:start, start - 1:start],
lambda: kernel[start - 1:start, start - 1:start] + K.zeros( # pylint: disable=g-long-lambda
(2, 2), dtype=kernel.dtype))
index = K.switch(
K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
lambda: K.constant(0, dtype='int32'),
lambda: K.constant(1, dtype='int32'))
while_condition = lambda index, *args: K.less(index, start)
def body_fn(i, array):
return i + 1, array_ops.pad(
array,
padding,
constant_values=kernel[start + i, start + i])
_, kernel_new = control_flow_ops.while_loop(
while_condition,
body_fn,
[index, kernel_new],
shape_invariants=[index.get_shape(),
tensor_shape.TensorShape([None, None])])
return kernel_new
# Aliases.
max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint
# Legacy aliases.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm
def serialize(constraint):
return serialize_keras_object(constraint)
def deserialize(config, custom_objects=None):
return deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='constraint')
def get(identifier):
if identifier is None:
return None
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret constraint identifier: ' +
str(identifier))
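# A hedged sketch of the identifier plumbing above (comment-only):
#
#   get('max_norm')   # -> a MaxNorm instance, via the lowercase alias
#   get({'class_name': 'MinMaxNorm', 'config': {'rate': 0.5}})
#   get(None)         # -> None (no constraint)
#   get(lambda w: w)  # callables pass through unchanged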

tensorflow/python/frozen_keras/engine/BUILD

@@ -1,151 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
# TODO(scottzhu): Clean up all the deps to python/keras
py_library(
name = "legacy_base_layer",
srcs = ["legacy_base_layer.py"],
deps = [
":base_layer_utils",
":input_spec",
":node",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:auto_control_deps",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:func_graph",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/autograph/core",
"//tensorflow/python/autograph/impl",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:execute",
"//tensorflow/python/eager:function",
"//tensorflow/python/frozen_keras:backend",
"//tensorflow/python/frozen_keras:constraint",
"//tensorflow/python/frozen_keras:initializers",
"//tensorflow/python/frozen_keras:regularizers",
"//tensorflow/python/frozen_keras/utils:generic_utils",
"//tensorflow/python/frozen_keras/utils:layer_utils",
"//tensorflow/python/frozen_keras/utils:tf_utils",
"//tensorflow/python/keras:metrics",
"//tensorflow/python/module",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/training/tracking",
"//tensorflow/python/training/tracking:base",
"//tensorflow/python/training/tracking:data_structures",
"//tensorflow/python/training/tracking:layer_utils",
"//tensorflow/tools/docs:doc_controls",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_library(
name = "base_layer_utils",
srcs = ["base_layer_utils.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:control_flow_v2_func_graphs",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:tf2",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/frozen_keras:backend",
"//tensorflow/python/training/tracking:base",
],
)
py_library(
name = "input_spec",
srcs = ["input_spec.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/frozen_keras:backend",
"@six_archive//:six",
],
)
py_library(
name = "node",
srcs = ["node.py"],
deps = [
":base_layer_utils",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python/frozen_keras:backend",
],
)
tf_py_test(
name = "legacy_base_layer_test",
size = "medium",
srcs = ["legacy_base_layer_test.py"],
python_version = "PY3",
shard_count = 8,
tags = [
"no_rocm",
"nomac", # TODO(mihaimaruseac): b/127695564
],
deps = [
":legacy_base_layer",
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "base_layer_utils_test",
srcs = ["base_layer_utils_test.py"],
python_version = "PY3",
tags = [
"nomac", # TODO(mihaimaruseac): b/127695564
],
deps = [
":base_layer_utils",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "input_spec_test",
size = "small",
srcs = ["input_spec_test.py"],
python_version = "PY3",
tags = [
"nomac", # TODO(mihaimaruseac): b/127695564
],
deps = [
":input_spec",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)

tensorflow/python/frozen_keras/engine/base_layer_utils.py

@@ -1,781 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.python import tf2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.frozen_keras import backend
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_v2_func_graphs
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.tracking import base as tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
_call_context = threading.local()
def make_variable(name,
shape=None,
dtype=dtypes.float32,
initializer=None,
trainable=None,
caching_device=None,
validate_shape=True,
constraint=None,
use_resource=None,
collections=None,
synchronization=tf_variables.VariableSynchronization.AUTO,
aggregation=tf_variables.VariableAggregation.NONE,
partitioner=None): # pylint: disable=unused-argument
"""Temporary util to create a variable (relies on `variable_scope.variable`).
Some reuse-related technicalities prevent us from using
`variable_scope.get_variable()` directly, so we use a subcomponent
that has fewer constraints (`variable_scope.variable()`).
In the longer term, it seems like a similar "default variable creator" method
should exist in `Trackable` instead. When this happens, we can get
rid of this temporary solution.
TODO(fchollet): remove this method when no longer needed.
Arguments:
name: Variable name.
shape: Variable shape.
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
initializer: Initializer instance (callable).
trainable: Whether the variable should be part of the layer's
"trainable_variables" (e.g. variables, biases)
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
marked as non-trainable. `trainable` defaults to `True` unless
`synchronization` is set to `ON_READ`.
caching_device: Passed to `tf.Variable`.
validate_shape: Passed to `tf.Variable`.
constraint: Constraint instance (callable).
use_resource: Whether to use a `ResourceVariable`.
collections: List of graph collections keys. The new variable is added to
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
synchronization: Indicates when a distributed variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set to
`AUTO` and the current `DistributionStrategy` chooses
when to synchronize. If `synchronization` is set to `ON_READ`,
`trainable` must not be set to `True`.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
`tf.VariableAggregation`.
partitioner: Not handled at this time.
Returns:
Variable instance.
"""
initializing_from_value = False
if initializer is not None and not callable(initializer):
initializing_from_value = True
if initializing_from_value:
init_val = initializer
variable_dtype = None
else:
# Instantiate initializer if provided initializer is a type object.
if isinstance(
initializer,
(type(init_ops.Initializer), type(init_ops_v2.Initializer))):
initializer = initializer()
init_val = lambda: initializer(shape, dtype=dtype)
variable_dtype = dtype.base_dtype
if use_resource is None:
use_resource = True
# TODO(apassos,rohanj) figure out how to remove collections from here so we
# can remove the V1.
variable_shape = tensor_shape.TensorShape(shape)
return tf_variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
caching_device=caching_device,
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
use_resource=use_resource,
collections=collections,
synchronization=synchronization,
aggregation=aggregation,
shape=variable_shape if variable_shape else None)
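# A hedged usage sketch (comment-only; `dtypes` and `init_ops` are already
# imported at the top of this module):
#
#   kernel = make_variable(
#       'kernel',
#       shape=(3, 4),
#       dtype=dtypes.float32,
#       initializer=init_ops.glorot_uniform_initializer(),
#       trainable=True)
#   # Returns a `VariableV1` whose initial value is produced lazily by
#   # the initializer callable above.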
def collect_previous_mask(input_tensors):
"""Retrieves the output mask(s) of the previous node.
Arguments:
input_tensors: An arbitrary structure of Tensors.
Returns:
A mask tensor or list of mask tensors.
"""
def _collect_previous_mask(x):
return getattr(x, '_keras_mask', None)
return nest.map_structure(_collect_previous_mask, input_tensors)
def have_all_keras_metadata(tensors):
return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
def generate_placeholders_from_shape(shape):
return array_ops.placeholder(shape=shape, dtype=backend.floatx())
def create_keras_history(tensors):
"""Wraps TensorFlow Operations for compatibility with the Functional API.
This method checks to see if a Tensor in `tensors` is missing Keras metadata
and has its origin in a Keras `Input` Layer. If so, this method will replace
the raw TensorFlow Operations that created this tensor with
`TensorFlowOpLayer` instances that create identical operations.
Any Tensors not originating from a Keras `Input` Layer will be treated as
constants when constructing `TensorFlowOpLayer` instances.
Arguments:
tensors: A structure of Tensors, some of which come from raw TensorFlow
operations and need to have Keras metadata assigned to them.
Returns:
created_layers: List. The `TensorFlowOpLayer` instances created to wrap
the raw TensorFlow operations.
"""
_, created_layers = _create_keras_history_helper(tensors, set(), [])
return created_layers
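# A hedged, illustrative-only sketch of when this runs (comment-only;
# `tf.keras.Input` here stands in for any symbolic Keras input):
#
#   inputs = tf.keras.Input(shape=(4,))
#   outputs = 2 * inputs                     # raw TF op, no Keras metadata
#   created = create_keras_history(outputs)  # wraps the op in a
#                                            # `TensorFlowOpLayer`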
def _create_keras_history_helper(tensors, processed_ops, created_layers):
"""Helper method for `create_keras_history`.
Arguments:
tensors: A structure of Tensors for which to create Keras metadata.
processed_ops: Set. TensorFlow operations that have already been wrapped in
`TensorFlowOpLayer` instances.
created_layers: List. The `TensorFlowOpLayer` instances created.
Returns:
Tuple. First element is the updated set of TensorFlow Operations that
have been wrapped in `TensorFlowOpLayer` instances. Second element is
a list of the `TensorFlowOpLayer` instances created.
"""
# Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
# Cannot be imported at top because of circular dependencies.
# TODO(omalleyt): Resolve circular dependency.
from tensorflow.python.frozen_keras.engine import legacy_base_layer as base_layer # pylint: disable=g-import-not-at-top
tensor_list = nest.flatten(tensors)
for tensor in tensor_list:
if getattr(tensor, '_keras_history', None) is not None:
continue
op = tensor.op # The Op that created this Tensor.
if op not in processed_ops:
if op.type.startswith('Sparse'):
lambda_example = """
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(input)
"""
raise ValueError(
'Sparse ops are not supported with functional models with built-in '
'layer wrapping. Please wrap the sparse ops in a Lambda layer like'
': \n{lambda_example}\n'.format(lambda_example=lambda_example))
# Recursively set `_keras_history`.
op_inputs = list(op.inputs)
constants = {}
layer_inputs = []
for i, op_input in enumerate(op_inputs):
if uses_keras_history(op_input):
layer_inputs.append(op_input)
else:
# Treat any value not originating from a `keras.Input` as
# a constant. Variables cannot be supported.
ds_with_session = (
distribution_strategy_context.in_cross_replica_context() and
not ops.executing_eagerly_outside_functions())
using_xla = control_flow_util.GraphOrParentsInXlaContext(
ops.get_default_graph())
if ds_with_session or using_xla:
# In Legacy Graph mode, evaluating here makes Session be
# configured improperly. The downside of this is that saving
# via `get_config` breaks, but SavedModel still works.
constants[i] = op_input
else:
with ops.init_scope():
constants[i] = backend.function([], op_input)([])
layer_inputs = unnest_if_single_tensor(layer_inputs)
processed_ops, created_layers = _create_keras_history_helper(
layer_inputs, processed_ops, created_layers)
name = op.name
node_def = op.node_def.SerializeToString()
op_layer = base_layer.TensorFlowOpLayer(
node_def, constants=constants, name=name)
created_layers.append(op_layer)
op_layer._add_inbound_node( # pylint: disable=protected-access
layer_inputs, op.outputs)
processed_ops.update([op])
return processed_ops, created_layers
def unnest_if_single_tensor(input_tensors):
# Preserve compatibility with older configs
flat_input_tensors = nest.flatten(input_tensors)
# If this is a single element but not a dict, unwrap. If this is a dict,
# assume the first layer expects a dict (as is the case with a
# DenseFeatures layer); pass through.
if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
input_tensors = flat_input_tensors[0]
return input_tensors
def needs_keras_history(tensors, ignore_call_context=False):
"""Check if any Tensors need to be wrapped in TensorFlowOpLayers.
This will never return True inside a sublayer, because sublayers
do not need to create Keras History. Otherwise, this returns True
if one or more of `tensors` originates from a `keras.Input` and
does not have `_keras_history` set.
Arguments:
tensors: An arbitrary nested structure of Tensors.
ignore_call_context: Whether to ignore the check of if currently
outside of a `call` context. This is `True` when creating
KerasHistory inside `Node`, where we always know that Tensors
are being used with the Functional API.
Returns:
Bool, whether at least one Tensor needs to be wrapped.
"""
input_tensors = nest.flatten(tensors)
if call_context().in_call and not ignore_call_context:
return False
if all(
getattr(tensor, '_keras_history', None) is not None
for tensor in input_tensors):
# KerasHistory already set.
return False
return uses_keras_history(tensors)
def is_in_keras_graph():
"""Returns if currently executing inside of a Keras graph."""
return call_context().in_keras_graph
def is_in_eager_or_tf_function():
"""Returns if in eager mode or inside of a tf.function."""
return context.executing_eagerly() or is_in_tf_function()
def is_in_tf_function():
"""Returns if inside of a tf.function."""
# Check if running in V1 graph mode.
if not ops.executing_eagerly_outside_functions():
return False
if not ops.inside_function():
return False
# Check if inside Keras FuncGraph.
if is_in_keras_graph():
return False
# Check for a v1 `wrap_function` FuncGraph.
graph = ops.get_default_graph()
if (getattr(graph, 'name', False) and
graph.name.startswith('wrapped_function')):
return False
return True
def uses_keras_history(tensors):
"""Check if at least one Tensor originates from a `keras.Input`.
This is `True` if at least one Tensor has its origin in a `keras.Input`.
Any Tensor that originates from a `keras.Input` will have a dependency
Tensor with a `_keras_history` attribute attached. Tensors that have
already been checked to not originate from a `keras.Input`
are marked as `_keras_history_checked`.
Arguments:
tensors: An arbitrary nested structure of Tensors.
Returns:
Bool, whether at least one Tensor originates from a `keras.Input`.
"""
checked_tensors = set()
tensors_to_check = nest.flatten(tensors)
while tensors_to_check:
new_tensors_to_check = []
for tensor in tensors_to_check:
if id(tensor) in checked_tensors:
continue
checked_tensors.add(id(tensor))
if getattr(tensor, '_keras_history_checked', None) is not None:
continue
if getattr(tensor, '_keras_history', None) is not None:
return True
try:
new_tensors_to_check.extend(tensor.op.inputs)
except AttributeError:
# In case `tensor` is a Variable created in an Eager context.
pass
tensors_to_check = new_tensors_to_check
# Mark that these Tensors have been checked once for `_keras_history`,
# and should not be checked again for performance reasons.
mark_checked(tensors)
return False
def mark_checked(tensors):
"""Marks that these Tensors should not be tracked.
This prevents Layers from attempting to create TensorFlowOpLayers
for these Tensors.
Arguments:
tensors: An arbitrary structure of Tensors.
"""
def _mark_checked(tensor):
tensor._keras_history_checked = True # pylint: disable=protected-access
nest.map_structure(_mark_checked, tensors)
def call_context():
"""Returns currently active `CallContext`."""
if getattr(_call_context, 'call_context', None) is None:
_call_context.call_context = CallContext()
return _call_context.call_context
class CallContext(object):
"""Keeps track of properties currently inside a Layer/Model's `call`.
Attributes:
layer: The `Layer` whose `call` is currently active.
inputs: The inputs to the currently active `Layer`.
frozen: Whether currently executing inside a `Layer` with `trainable` set to
`False`.
in_call: Whether currently inside the `call` of a Layer.
training: Whether currently executing in training or inference mode.
in_keras_graph: Whether executing inside the Keras Graph.
saving: Whether currently saving to SavedModel.
"""
def __init__(self):
self.layer = None
self.inputs = None
self.frozen = False
self.in_call = False
self.training = None
self._in_keras_graph = False
self.saving = False
@tf_contextlib.contextmanager
def enter(self, layer, inputs, build_graph, training, saving=None):
"""Push a Layer and its inputs and state onto the current call context."""
prev_layer = self.layer
prev_inputs = self.inputs
prev_frozen = self.frozen
prev_in_call = self.in_call
prev_training = self.training
prev_in_keras_graph = self._in_keras_graph
prev_saving = self.saving
self.layer = layer
self.inputs = inputs
self.frozen = self.frozen or not layer.trainable
self.in_call = True
self.training = training
self._in_keras_graph = (
self._in_keras_graph or
(build_graph and
getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
self.saving = prev_saving if saving is None else saving
try:
yield
finally:
self.layer = prev_layer
self.inputs = prev_inputs
self.frozen = prev_frozen
self.in_call = prev_in_call
self.training = prev_training
self._in_keras_graph = prev_in_keras_graph
self.saving = prev_saving
@property
def in_keras_graph(self):
# Returns True even if in a subgraph of the Keras graph, such as those
# created by control flow ops.
if context.executing_eagerly():
return False
return (self._in_keras_graph or
getattr(backend.get_graph(), 'name', None) == 'keras_graph')
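# A hedged sketch of the context-manager protocol above (comment-only;
# `my_layer` and `x` are hypothetical):
#
#   ctx = call_context()
#   with ctx.enter(layer=my_layer, inputs=x, build_graph=False,
#                  training=True):
#     assert ctx.in_call and ctx.training
#   # On exit, the previous layer/inputs/training state is restored.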
def training_arg_passed_to_call(argspec, args, kwargs):
"""Returns whether a user passed the `training` argument in `__call__`."""
# `argspec.args` starts with ['self', 'inputs']
full_args = dict(zip(argspec.args[2:], args))
full_args.update(kwargs)
return 'training' in full_args and full_args['training'] is not None
def autocast_context_manager(dtype):
"""Returns a context manager to autocast AutoCastVariables.
Under this context manager, AutoCastVariables will be cast to `dtype` if
`dtype` is floating-point. Otherwise, AutoCastVariables will not be cast.
Args:
dtype: The dtype to cast AutoCastVariables to, or None.
Returns:
A context manager to automatically cast AutoCastVariables.
"""
if dtype and not dtypes.as_dtype(dtype).is_floating:
dtype = None
return ops.get_default_graph()._enable_auto_casting_variables(dtype) # pylint: disable=protected-access
def is_subclassed(layer):
"""Returns True if the object is a subclassed layer or subclassed model."""
return (layer.__module__.find('keras.engine') == -1 and
layer.__module__.find('keras.layers') == -1)
def from_saved_model(layer):
"""Returns whether the layer is loaded from a SavedModel."""
return layer.__module__.find('keras.saving.saved_model') != -1
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
"""Checks that tensors passed to `add_*` method match the Keras graph.
When one of the `add_*` methods is called inside a V2 conditional branch,
the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
We need to raise clear error messages in such cases.
Arguments:
tensor: Tensor to check, or `False` if it is known that an error
should be raised.
method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
force_raise: If an error should be raised regardless of `tensor`.
Raises:
RuntimeError: In case of an out-of-graph tensor.
"""
if (force_raise or
(ops.executing_eagerly_outside_functions() and
hasattr(tensor, 'graph') and
isinstance(tensor.graph,
(control_flow_v2_func_graphs.CondBranchFuncGraph,
control_flow_v2_func_graphs.WhileCondFuncGraph,
control_flow_v2_func_graphs.WhileBodyFuncGraph)))):
if method == 'activity_regularizer':
bad_example = """
class TestModel(tf.keras.Model):
def __init__(self):
super(TestModel, self).__init__(name='test_model')
self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')
def call(self, x, training=None):
if training:
return self.dense(x)
else:
return self.dense(x)
"""
correct_example = """
class TestModel(tf.keras.Model):
def __init__(self):
super(TestModel, self).__init__(name='test_model')
self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')
def call(self, x, training=None):
return self.dense(x)
"""
raise RuntimeError(
'You are using a layer with `activity_regularizer` in a control flow '
'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
'Please move your call to the layer with `activity_regularizer` out '
'of the control flow branch, e.g.:\n{correct_example}\n'
'You can also resolve this by marking your outer model/layer dynamic'
' (eager-only) by passing `dynamic=True` to the layer constructor. '
'Any kind of control flow is supported with dynamic layers. '
'Note that using `dynamic=True` requires you to implement static '
'shape inference in the `compute_output_shape(input_shape)` '
'method.'.format(
bad_example=bad_example, correct_example=correct_example))
if method == 'add_metric':
bad_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
metric = compute_metric(inputs)
else:
metric = 0.
self.add_metric(metric, name='my_metric', aggregation='mean')
return inputs
"""
elif method == 'add_loss':
bad_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
self.add_loss(loss)
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
loss = compute_loss(inputs)
else:
loss = 0.
self.add_loss(loss)
return inputs
"""
else:
bad_example = """
def call(self, inputs, training=None):
if training:
self.add_update(self.w.assign_add(1))
return inputs
"""
correct_example = """
def call(self, inputs, training=None):
if training:
increment = 1
else:
increment = 0
self.add_update(self.w.assign_add(increment))
return inputs
"""
raise RuntimeError(
'You are using the method `{method}` in a control flow branch '
'in your layer, e.g.:\n{bad_example}\n'
'This is not currently supported. '
'Please move your call to {method} out of the control flow branch, '
'e.g.:\n{correct_example}\n'
'You can also resolve this by marking your layer '
'as dynamic (eager-only) by passing '
'`dynamic=True` to the layer constructor. '
'Any kind of control flow is supported with dynamic layers. '
'Note that using `dynamic=True` requires you '
'to implement static shape inference '
'in the `compute_output_shape(input_shape)` method.'.format(
method=method,
bad_example=bad_example,
correct_example=correct_example))
def mark_as_return(outputs, acd):
"""Marks `outputs` as the return values for automatic control deps."""
def _mark_as_return(tensor):
"""Marks `tensor` as the return value for automatic control deps."""
if not tensor_util.is_tensor(tensor):
return tensor
# pylint: disable=protected-access
return_tensor = acd.mark_as_return(tensor)
if getattr(tensor, '_keras_mask', None) is not None:
return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
else:
return_tensor._keras_mask = None
# Handle TensorFlow Probability attached metadata.
# TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
if getattr(tensor, '_tfp_distribution', None) is not None:
return_tensor._tfp_distribution = tensor._tfp_distribution
return return_tensor
# pylint: enable=protected-access
return nest.map_structure(_mark_as_return, outputs)
V2_DTYPE_BEHAVIOR = None
# These two functions are not exported because we plan on removing them in the
# future.
def enable_v2_dtype_behavior():
"""Enable the V2 dtype behavior for Keras layers.
By default, the V2 dtype behavior is enabled in TensorFlow 2.
When enabled, the dtype of Keras layers defaults to floatx (which is typically
float32) instead of None. In addition, layers will automatically cast
floating-point inputs to the layer's dtype.
For example, once enabled, the following block will run a Conv2D layer
in float32:
```python
x = tf.ones((4, 4, 4, 4), dtype='float64')
layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
print(layer.dtype)  # float32 when enabled, None when disabled.
# When enabled, inputs are cast to the layer's dtype, which is float32. When
# disabled, no casting is done, so the computation runs in float64.
y = layer(x)
```
A layer author can opt-out their layer from the automatic input casting by
passing `autocast=False` to the base Layer's constructor. This disables the
autocasting part of the V2 behavior for that layer, but not the defaulting to
floatx part of the V2 behavior.
When a global `tf.keras.mixed_precision.experimental.Policy` is set, the
layer's dtype will default to the global policy instead of floatx. Layers
will automatically cast inputs to the policy's compute_dtype.
"""
global V2_DTYPE_BEHAVIOR
V2_DTYPE_BEHAVIOR = True
def disable_v2_dtype_behavior():
"""Disables the V2 dtype behavior for Keras layers.
See `enable_v2_dtype_behavior`.
This function will be removed in the future.
"""
global V2_DTYPE_BEHAVIOR
V2_DTYPE_BEHAVIOR = False
def v2_dtype_behavior_enabled():
"""Returns True if the V2 dtype behavior is enabled."""
if V2_DTYPE_BEHAVIOR is None:
return tf2.enabled()
return V2_DTYPE_BEHAVIOR
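# --- Editor's illustrative sketch (not part of the original file) ---
# What the V2 dtype behavior means in practice, assuming TF2 defaults (where
# it is already enabled). `NoAutocastLayer` is a hypothetical name.
import tensorflow as tf

x = tf.ones((4, 4), dtype='float64')
layer = tf.keras.layers.Dense(2)   # dtype defaults to floatx ('float32')
y = layer(x)                       # float64 input is autocast to float32
print(layer.dtype, y.dtype)        # float32 float32

class NoAutocastLayer(tf.keras.layers.Layer):
  def __init__(self):
    # Opts out of input autocasting; the floatx default still applies.
    super(NoAutocastLayer, self).__init__(autocast=False)

  def call(self, inputs):
    return inputs                  # float64 inputs stay float64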
class TrackableWeightHandler(object):
"""Keras wrapper for handling tracking.Trackable object saving and restoring.
This class handles Trackables in both V1 and V2 modes, ensuring that they can
be saved and restored with the correct data and without adding additional ops
on every save.
Attributes:
trackable: The trackable to wrap.
num_tensors: The number of tensors that this trackable requires for saving.
"""
def __init__(self, trackable):
if not isinstance(trackable, tracking.Trackable):
raise ValueError('%s is not a Trackable object.' % (trackable,))
self._trackable = trackable
# TODO(b/141682913): Figure out why this is private and fix it.
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
if len(saveables) != 1:
raise ValueError('Only Trackables with one Saveable are supported.')
saveable = list(saveables)[0]
if ops.executing_eagerly_outside_functions():
# If we're in eager mode, we need to defer calling the Trackable's
# saveable() callable until data export time.
# However, it is safe to call the saveable as many times as we want, so
# we will call it now to figure out how many tensors this Trackable will
# produce.
self._saveable = saveable
self._num_tensors = len(self._saveable().specs)
self._setter = lambda weights: self._saveable().restore(weights, None)
self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
else:
# If we're in Graph mode, we need to evaluate the Saveable only once and
# cache the resulting restore graph. Failing to do this will result in
# new assignment ops being added to the graph each time set_weights() is
# called.
self._placeholder_tensors = []
self._saveable = saveable()
self._num_tensors = len(self._saveable.specs)
for spec in self._saveable.specs:
tensor = spec.tensor
self._placeholder_tensors.append(
array_ops.placeholder(tensor.dtype, tensor.shape))
self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
self._setter = self._set_weights_v1
self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
@property
def num_tensors(self):
return self._num_tensors
def set_weights(self, weights):
if len(weights) != self._num_tensors:
raise ValueError(
('Weight handler for trackable %s received the wrong number of ' +
'weights: expected %s, got %s.') %
(self._trackable, self._num_tensors, len(weights)))
self._setter(weights)
def get_tensors(self):
return self._getter()
def _set_weights_v1(self, weights):
feed_dict = {}
for idx, tensor in enumerate(weights):
feed_dict[self._placeholder_tensors[idx]] = tensor
backend.get_session().run(self._assign_op, feed_dict)
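# --- Editor's usage sketch (not part of the original file) ---
# Mirrors the handler test further below: wrap a mutable hash table so it can
# be saved and restored like a weight.
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import lookup_ops

table = lookup_ops.MutableHashTable(
    key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
handler = TrackableWeightHandler(table)
print(handler.num_tensors)                   # 2: a keys tensor and a values tensor
handler.set_weights([[b'a', b'b'], [1, 2]])  # restore-style assignment
keys, values = handler.get_tensors()         # export-style read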
# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes issues
# in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
'This layer\'s losses have been added to the parent layer.')


@ -1,71 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.frozen_keras import backend
from tensorflow.python.frozen_keras.engine import base_layer_utils
from tensorflow.python.ops import lookup_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class TrackableWeightHandlerTest(test.TestCase, parameterized.TestCase):
def get_table_handler(self):
# Note: There is some repetition in these tests' setup. However, TensorFlow
# does not play nicely with a separate setUp() call (causing errors related
# to graph building), so we call this setup helper explicitly instead of
# using a setUp() override.
table = lookup_ops.MutableHashTable(
key_dtype=dtypes.string, value_dtype=dtypes.int32, default_value=0)
return base_layer_utils.TrackableWeightHandler(table)
def test_get_num_tensors(self):
table_handler = self.get_table_handler()
self.assertEqual(2, table_handler.num_tensors)
def test_get_and_set_weights(self):
table_handler = self.get_table_handler()
table_data = {b"a": 1, b"b": 2, b"c": 3}
table_handler.set_weights(
[list(table_data.keys()),
list(table_data.values())])
weights = backend.batch_get_value(table_handler.get_tensors())
weight_data = {key: value for key, value in zip(weights[0], weights[1])}
self.assertDictEqual(table_data, weight_data)
def test_get_and_set_weights_does_not_add_ops(self):
table_handler = self.get_table_handler()
table_data = {b"a": 1, b"b": 2, b"c": 3}
table_handler.set_weights(
[list(table_data.keys()),
list(table_data.values())])
_ = backend.batch_get_value(table_handler.get_tensors())
backend.get_session().graph.finalize()
table_handler.set_weights(
[list(table_data.keys()),
list(table_data.values())])
_ = backend.batch_get_value(table_handler.get_tensors())
if __name__ == "__main__":
test.main()


@ -1,233 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the InputSpec class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.frozen_keras import backend
from tensorflow.python.util import nest
class InputSpec(object):
"""Specifies the rank, dtype and shape of every input to a layer.
Layers can expose (if appropriate) an `input_spec` attribute:
an instance of `InputSpec`, or a nested structure of `InputSpec` instances
(one per input tensor). These objects enable the layer to run input
compatibility checks for input structure, input rank, input shape, and
input dtype.
A None entry in a shape is compatible with any dimension;
a None shape is compatible with any shape.
Arguments:
dtype: Expected DataType of the input.
shape: Shape tuple, expected shape of the input
(may include None for unchecked axes).
ndim: Integer, expected rank of the input.
max_ndim: Integer, maximum rank of the input.
min_ndim: Integer, minimum rank of the input.
axes: Dictionary mapping integer axes to
a specific dimension value.
"""
def __init__(self,
dtype=None,
shape=None,
ndim=None,
max_ndim=None,
min_ndim=None,
axes=None):
self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
if shape is not None:
self.ndim = len(shape)
self.shape = shape
else:
self.ndim = ndim
self.shape = None
self.max_ndim = max_ndim
self.min_ndim = min_ndim
try:
axes = axes or {}
self.axes = {int(k): axes[k] for k in axes}
except (ValueError, TypeError):
raise TypeError('The keys in axes must be integers.')
if self.axes and (self.ndim is not None or self.max_ndim is not None):
max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
max_axis = max(self.axes)
if max_axis > max_dim:
raise ValueError('Axis {} is greater than the maximum allowed value: {}'
.format(max_axis, max_dim))
def __repr__(self):
spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
('shape=' + str(self.shape)) if self.shape else '',
('ndim=' + str(self.ndim)) if self.ndim else '',
('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
('axes=' + str(self.axes)) if self.axes else '']
return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
def get_config(self):
return {
'dtype': self.dtype,
'shape': self.shape,
'ndim': self.ndim,
'max_ndim': self.max_ndim,
'min_ndim': self.min_ndim,
'axes': self.axes}
@classmethod
def from_config(cls, config):
return cls(**config)
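# --- Editor's usage sketch (not part of the original file) ---
# A custom layer can declare an `input_spec` so the checks implemented in
# `assert_input_compatibility` below run on every call. Assumes the public
# `tf.keras` API; the layer name is hypothetical.
import tensorflow as tf

class LastAxisThree(tf.keras.layers.Layer):
  def __init__(self):
    super(LastAxisThree, self).__init__()
    # Require rank-2 inputs whose last axis has size 3.
    self.input_spec = tf.keras.layers.InputSpec(ndim=2, axes={-1: 3})

  def call(self, inputs):
    return inputs

layer = LastAxisThree()
layer(tf.zeros((5, 3)))    # passes the spec
# layer(tf.zeros((5, 4)))  # would raise ValueError: expected axis -1 of input
#                          # shape to have value 3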
def to_tensor_shape(spec):
"""Returns a tf.TensorShape object that matches the shape specifications.
If the InputSpec's shape or ndim is defined, this method will return a fully
or partially-known shape. Otherwise, the returned TensorShape is fully unknown.
Args:
spec: an InputSpec object.
Returns:
a tf.TensorShape object
"""
if spec.ndim is None and spec.shape is None:
return tensor_shape.TensorShape(None)
elif spec.shape is not None:
return tensor_shape.TensorShape(spec.shape)
else:
shape = [None] * spec.ndim
for a in spec.axes:
shape[a] = spec.axes[a] # Assume that axes is defined
return tensor_shape.TensorShape(shape)
def assert_input_compatibility(input_spec, inputs, layer_name):
"""Checks compatibility between the layer and provided inputs.
This checks that the tensor(s) `inputs` verify the input assumptions
of a layer (if any). If not, a clear and actionable exception gets raised.
Arguments:
input_spec: An InputSpec instance, list of InputSpec instances, a nested
structure of InputSpec instances, or None.
inputs: Input tensor, list of input tensors, or a nested structure of
input tensors.
layer_name: String, name of the layer (for error message formatting).
Raises:
ValueError: in case of mismatch between
the provided inputs and the expectations of the layer.
"""
if not input_spec:
return
inputs = nest.flatten(inputs)
input_spec = nest.flatten(input_spec)
if len(inputs) != len(input_spec):
raise ValueError('Layer ' + layer_name + ' expects ' +
str(len(input_spec)) + ' inputs, '
'but it received ' + str(len(inputs)) +
' input tensors. Inputs received: ' + str(inputs))
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
if spec is None:
continue
if (spec.ndim is not None or
spec.min_ndim is not None or
spec.max_ndim is not None):
if x.shape.ndims is None:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
layer_name + ' is incompatible with the layer: '
'its rank is undefined, but the layer requires a '
'defined rank.')
# Check ndim.
if spec.ndim is not None:
ndim = x.shape.ndims
if ndim != spec.ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
layer_name + ' is incompatible with the layer: '
'expected ndim=' + str(spec.ndim) + ', found ndim=' +
str(ndim) + '. Full shape received: ' +
str(x.shape.as_list()))
if spec.max_ndim is not None:
ndim = x.shape.ndims
if ndim is not None and ndim > spec.max_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
layer_name + ' is incompatible with the layer: '
'expected max_ndim=' + str(spec.max_ndim) +
', found ndim=' + str(ndim))
if spec.min_ndim is not None:
ndim = x.shape.ndims
if ndim is not None and ndim < spec.min_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
layer_name + ' is incompatible with the layer: '
'expected min_ndim=' + str(spec.min_ndim) +
', found ndim=' + str(ndim) +
'. Full shape received: ' +
str(x.shape.as_list()))
# Check dtype.
if spec.dtype is not None:
if x.dtype != spec.dtype:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
layer_name + ' is incompatible with the layer: '
'expected dtype=' + str(spec.dtype) +
', found dtype=' + str(x.dtype))
# Check specific shape axes.
if spec.axes:
shape = x.shape.as_list()
if shape is not None:
for axis, value in spec.axes.items():
if hasattr(value, 'value'):
value = value.value
if value is not None and shape[int(axis)] not in {value, None}:
raise ValueError(
'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
' incompatible with the layer: expected axis ' + str(axis) +
' of input shape to have value ' + str(value) +
' but received input with shape ' + str(shape))
# Check shape.
if spec.shape is not None:
shape = x.shape.as_list()
if shape is not None:
for spec_dim, dim in zip(spec.shape, shape):
if spec_dim is not None and dim is not None:
if spec_dim != dim:
raise ValueError('Input ' + str(input_index) +
' is incompatible with layer ' + layer_name +
': expected shape=' + str(spec.shape) +
', found shape=' + str(shape))
def to_tensor_spec(input_spec, default_dtype=None):
"""Converts a Keras InputSpec object to a TensorSpec."""
default_dtype = default_dtype or backend.floatx()
if isinstance(input_spec, InputSpec):
dtype = input_spec.dtype or default_dtype
return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
return tensor_spec.TensorSpec(None, default_dtype)


@ -1,66 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""InputSpec tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.frozen_keras.engine import input_spec
from tensorflow.python.platform import test
class InputSpecTest(test.TestCase):
def test_axes_initialization(self):
input_spec.InputSpec(shape=[1, None, 2, 3], axes={3: 5, '2': 2})
with self.assertRaisesRegexp(ValueError, 'Axis 4 is greater than'):
input_spec.InputSpec(shape=[1, None, 2, 3], axes={4: 5})
with self.assertRaisesRegexp(TypeError, 'keys in axes must be integers'):
input_spec.InputSpec(shape=[1, None, 2, 3], axes={'string': 5})
class InputSpecToTensorShapeTest(test.TestCase):
def test_defined_shape(self):
spec = input_spec.InputSpec(shape=[1, None, 2, 3])
self.assertAllEqual(
[1, None, 2, 3], input_spec.to_tensor_shape(spec).as_list())
def test_defined_ndims(self):
spec = input_spec.InputSpec(ndim=5)
self.assertAllEqual(
[None] * 5, input_spec.to_tensor_shape(spec).as_list())
spec = input_spec.InputSpec(ndim=0)
self.assertAllEqual(
[], input_spec.to_tensor_shape(spec).as_list())
spec = input_spec.InputSpec(ndim=3, axes={1: 3, -1: 2})
self.assertAllEqual(
[None, 3, 2], input_spec.to_tensor_shape(spec).as_list())
def test_undefined_shapes(self):
spec = input_spec.InputSpec(max_ndim=5)
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
input_spec.to_tensor_shape(spec).as_list()
spec = input_spec.InputSpec(min_ndim=5, max_ndim=5)
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
input_spec.to_tensor_shape(spec).as_list()
if __name__ == '__main__':
test.main()

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -1,190 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the `Node` class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.frozen_keras import backend
from tensorflow.python.frozen_keras.engine import base_layer_utils
from tensorflow.python.util import nest
class Node(object):
"""A `Node` describes the connectivity between two layers.
Each time a layer is connected to some new input,
a node is added to `layer._inbound_nodes`.
Each time the output of a layer is used by another layer,
a node is added to `layer._outbound_nodes`.
Arguments:
outbound_layer: the layer that takes
`input_tensors` and turns them into `output_tensors`
(the node gets created when the `call`
method of the layer is called).
inbound_layers: a list of layers, the same length as `input_tensors`,
the layers from where `input_tensors` originate.
node_indices: a list of integers, the same length as `inbound_layers`.
`node_indices[i]` is the origin node of `input_tensors[i]`
(necessary since each inbound layer might have several nodes,
e.g. if the layer is being shared with a different data stream).
tensor_indices: a list of integers,
the same length as `inbound_layers`.
`tensor_indices[i]` is the index of `input_tensors[i]` within the
output of the inbound layer
(necessary since each inbound layer might
have multiple tensor outputs, with each one being
independently manipulable).
input_tensors: list of input tensors.
output_tensors: list of output tensors.
arguments: dictionary of keyword arguments that were passed to the
`call` method of the layer at the call that created the node.
`node_indices` and `tensor_indices` are basically fine-grained coordinates
describing the origin of the `input_tensors`.
A node from layer A to layer B is added to:
- A._outbound_nodes
- B._inbound_nodes
"""
def __init__(self,
outbound_layer,
inbound_layers,
node_indices,
tensor_indices,
input_tensors,
output_tensors,
arguments=None):
# Layer instance (NOT a sequence)
if isinstance(outbound_layer, (list, tuple, dict)):
raise ValueError('`outbound_layer` should be a layer instance, '
'not a list, tuple, or dict.')
# this is the layer that takes a nested structure of input tensors
# and turns them into a nested structure of output tensors.
# the current node will be added to
# the inbound_nodes of outbound_layer.
self.outbound_layer = outbound_layer
# The following 3 properties describe where
# the input tensors come from: which layers,
# and for each layer, which node and which
# tensor output of each node.
# Nested structure of layer instances.
self.inbound_layers = inbound_layers
# Nested structure of integers, 1:1 mapping with inbound_layers.
self.node_indices = node_indices
# Nested structure of integers, 1:1 mapping with inbound_layers.
self.tensor_indices = tensor_indices
# Following 2 properties:
# tensor inputs and outputs of outbound_layer.
# Nested structure of tensors. 1:1 mapping with inbound_layers.
self.input_tensors = input_tensors
# Nested structure of tensors, created by outbound_layer.call().
self.output_tensors = output_tensors
# Following 2 properties: input and output shapes.
# Nested structure of shape tuples, shapes of input_tensors.
self.input_shapes = nest.map_structure(backend.int_shape, input_tensors)
# Nested structure of shape tuples, shapes of output_tensors.
self.output_shapes = nest.map_structure(backend.int_shape, output_tensors)
# Optional keyword arguments to layer's `call`.
self.arguments = arguments
# Create Keras History for any Keras Tensors in `arguments`.
tensor_arguments = [
t for t in nest.flatten(self.arguments) if isinstance(t, ops.Tensor)
]
for tensor_argument in tensor_arguments:
if base_layer_utils.needs_keras_history(
tensor_argument, ignore_call_context=True):
base_layer_utils.create_keras_history(tensor_argument)
# Add nodes to all layers involved.
for layer in nest.flatten(inbound_layers):
if layer is not None:
# For compatibility with external Keras, we use the deprecated
# accessor here.
layer.outbound_nodes.append(self)
# For compatibility with external Keras, we use the deprecated
# accessor here.
outbound_layer.inbound_nodes.append(self)
def iterate_inbound(self, include_arguments=False):
"""Returns a list of tuples representing the inbound data.
Arguments:
include_arguments: Whether to also iterate over any Keras Tensors
passed as args, kwargs.
Returns:
List of tuples like: (inbound_layer, node_index, tensor_index, tensor).
"""
inputs_inbound = list(
zip(
nest.flatten(self.inbound_layers),
nest.flatten(self.node_indices),
nest.flatten(self.tensor_indices),
nest.flatten(self.input_tensors)))
if include_arguments:
keras_tensor_arguments = [
kt for kt in nest.flatten(self.arguments)
if hasattr(kt, '_keras_history')
]
def _get_inbound(keras_tensor):
kh = keras_tensor._keras_history
return kh.layer, kh.node_index, kh.tensor_index, keras_tensor
arguments_inbound = nest.map_structure(_get_inbound,
keras_tensor_arguments)
return inputs_inbound + arguments_inbound
else:
return inputs_inbound
def _get_all_node_dependencies(self):
"""Returns all of the nodes this node immediately depends on."""
node_deps = []
for layer, node_index, _, _ in self.iterate_inbound():
node_deps.append(layer._inbound_nodes[node_index])
for arg in nest.flatten(self.arguments):
if isinstance(arg, ops.Tensor) and hasattr(arg, '_keras_history'):
kh = arg._keras_history
node_deps.append(kh.layer._inbound_nodes[kh.node_index])
return node_deps
def get_config(self):
inbound_names = nest.map_structure(
lambda layer: layer.name if layer else None, self.inbound_layers)
return {
'outbound_layer': self.outbound_layer.name,
'inbound_layers': inbound_names,
'node_indices': self.node_indices,
'tensor_indices': self.tensor_indices
}
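# --- Editor's illustrative sketch (not part of the original file) ---
# Nodes are created implicitly by the functional API; each layer call appends
# one. Assumes the public `tf.keras` API (and the deprecated `inbound_nodes`
# accessor mentioned above).
import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
dense = tf.keras.layers.Dense(2)
out = dense(inp)  # creates one Node from the Input layer into `dense`

node = dense.inbound_nodes[0]
for layer, node_index, tensor_index, tensor in node.iterate_inbound():
  print(layer.name, node_index, tensor_index)  # e.g. input_1 0 0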


@ -1,198 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import init_ops_v2
# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
from tensorflow.python.ops.init_ops import GlorotNormal
from tensorflow.python.ops.init_ops import GlorotUniform
from tensorflow.python.ops.init_ops import he_normal # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import he_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_normal # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormal
from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniform
from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormal
from tensorflow.python.ops.init_ops import VarianceScaling # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
# pylint: disable=unused-import, disable=line-too-long
from tensorflow.python.ops.init_ops_v2 import Constant as ConstantV2
from tensorflow.python.ops.init_ops_v2 import GlorotNormal as GlorotNormalV2
from tensorflow.python.ops.init_ops_v2 import GlorotUniform as GlorotUniformV2
from tensorflow.python.ops.init_ops_v2 import he_normal as he_normalV2
from tensorflow.python.ops.init_ops_v2 import he_uniform as he_uniformV2
from tensorflow.python.ops.init_ops_v2 import Identity as IdentityV2
from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
from tensorflow.python.ops.init_ops_v2 import lecun_normal as lecun_normalV2
from tensorflow.python.ops.init_ops_v2 import lecun_uniform as lecun_uniformV2
from tensorflow.python.ops.init_ops_v2 import Ones as OnesV2
from tensorflow.python.ops.init_ops_v2 import Orthogonal as OrthogonalV2
from tensorflow.python.ops.init_ops_v2 import RandomNormal as RandomNormalV2
from tensorflow.python.ops.init_ops_v2 import RandomUniform as RandomUniformV2
from tensorflow.python.ops.init_ops_v2 import TruncatedNormal as TruncatedNormalV2
from tensorflow.python.ops.init_ops_v2 import VarianceScaling as VarianceScalingV2
from tensorflow.python.ops.init_ops_v2 import Zeros as ZerosV2
# pylint: enable=unused-import, enable=line-too-long
class TruncatedNormal(TFTruncatedNormal):
"""Initializer that generates a truncated normal distribution.
These values are similar to values from a `random_normal_initializer`
except that values more than two standard deviations from the mean
are discarded and re-drawn. This is the recommended initializer for
neural network weights and filters.
Args:
mean: a python scalar or a scalar tensor. Mean of the random values to
generate. Defaults to 0.
stddev: a python scalar or a scalar tensor. Standard deviation of the random
values to generate. Defaults to 0.05.
seed: A Python integer. Used to create random seeds. See
`tf.compat.v1.set_random_seed` for behavior.
dtype: The data type. Only floating point types are supported.
Returns:
A TruncatedNormal instance.
"""
def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
super(TruncatedNormal, self).__init__(
mean=mean, stddev=stddev, seed=seed, dtype=dtype)
class RandomUniform(TFRandomUniform):
"""Initializer that generates tensors with a uniform distribution.
Args:
minval: A python scalar or a scalar tensor. Lower bound of the range of
random values to generate. Defaults to -0.05.
maxval: A python scalar or a scalar tensor. Upper bound of the range of
random values to generate. Defaults to 0.05.
seed: A Python integer. Used to create random seeds. See
`tf.compat.v1.set_random_seed` for behavior.
dtype: The data type.
Returns:
A RandomUniform instance.
"""
def __init__(self, minval=-0.05, maxval=0.05, seed=None,
dtype=dtypes.float32):
super(RandomUniform, self).__init__(
minval=minval, maxval=maxval, seed=seed, dtype=dtype)
class RandomNormal(TFRandomNormal):
"""Initializer that generates tensors with a normal distribution.
Args:
mean: a python scalar or a scalar tensor. Mean of the random values to
generate. Defaults to 0.
stddev: a python scalar or a scalar tensor. Standard deviation of the random
values to generate. Defaults to 0.05.
seed: A Python integer. Used to create random seeds. See
`tf.compat.v1.set_random_seed` for behavior.
dtype: The data type. Only floating point types are supported.
Returns:
RandomNormal instance.
"""
def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
super(RandomNormal, self).__init__(
mean=mean, stddev=stddev, seed=seed, dtype=dtype)
# Compatibility aliases
# pylint: disable=invalid-name
zero = zeros = Zeros
one = ones = Ones
constant = Constant
uniform = random_uniform = RandomUniform
normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform
# Utility functions
def serialize(initializer):
return serialize_keras_object(initializer)
def deserialize(config, custom_objects=None):
"""Return an `Initializer` object from its config."""
if tf2.enabled():
# Class names are the same for V1 and V2 but the V2 classes
# are aliased in this file so we need to grab them directly
# from `init_ops_v2`.
module_objects = {
obj_name: getattr(init_ops_v2, obj_name)
for obj_name in dir(init_ops_v2)
}
else:
module_objects = globals()
return deserialize_keras_object(
config,
module_objects=module_objects,
custom_objects=custom_objects,
printable_module_name='initializer')
def get(identifier):
if identifier is None:
return None
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
identifier = str(identifier)
# We have to special-case functions that return classes.
# TODO(omalleyt): Turn these into classes or class aliases.
special_cases = ['he_normal', 'he_uniform', 'lecun_normal', 'lecun_uniform']
if identifier in special_cases:
# Treat like a class.
return deserialize({'class_name': identifier, 'config': {}})
return deserialize(identifier)
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret initializer identifier: ' +
str(identifier))
# pylint: enable=invalid-name


@ -1,266 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in regularizers."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.frozen_keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
class Regularizer(object):
"""Regularizer base class.
Regularizers allow you to apply penalties on layer parameters or layer
activity during optimization. These penalties are summed into the loss
function that the network optimizes.
Regularization penalties are applied on a per-layer basis. The exact API will
depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D` and
`Conv3D`) have a unified API.
These layers expose 3 keyword arguments:
- `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel
- `bias_regularizer`: Regularizer to apply a penalty on the layer's bias
- `activity_regularizer`: Regularizer to apply a penalty on the layer's output
All layers (including custom layers) expose `activity_regularizer` as a
settable property, whether or not it is in the constructor arguments.
The value returned by the `activity_regularizer` is divided by the input
batch size so that the relative weighting between the weight regularizers and
the activity regularizers does not change with the batch size.
You can access a layer's regularization penalties by calling `layer.losses`
after calling the layer on inputs.
## Example
>>> layer = tf.keras.layers.Dense(
... 5, input_dim=5,
... kernel_initializer='ones',
... kernel_regularizer=tf.keras.regularizers.l1(0.01),
... activity_regularizer=tf.keras.regularizers.l2(0.01))
>>> tensor = tf.ones(shape=(5, 5)) * 2.0
>>> out = layer(tensor)
>>> # The kernel regularization term is 0.25
>>> # The activity regularization term (after dividing by the batch size) is 5
>>> tf.math.reduce_sum(layer.losses)
<tf.Tensor: shape=(), dtype=float32, numpy=5.25>
## Available penalties
```python
tf.keras.regularizers.l1(0.3) # L1 Regularization Penalty
tf.keras.regularizers.l2(0.1) # L2 Regularization Penalty
tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01) # L1 + L2 penalties
```
## Directly calling a regularizer
Compute a regularization loss on a tensor by directly calling a regularizer
as if it is a one-argument function.
E.g.
>>> regularizer = tf.keras.regularizers.l2(2.)
>>> tensor = tf.ones(shape=(5, 5))
>>> regularizer(tensor)
<tf.Tensor: shape=(), dtype=float32, numpy=50.0>
### A note on serialization and deserialization:
Registering the regularizers as serializable is optional if you are just
training and executing models, exporting to and from SavedModels, or saving
and loading weight checkpoints.
Registration is required for Keras `model_to_estimator`, saving and
loading models to HDF5 formats, Keras model cloning, some visualization
utilities, and exporting models to and from JSON. If using this functionality,
you must make sure any python process running your model has also defined
and registered your custom regularizer.
`tf.keras.utils.register_keras_serializable` is only available in TF 2.1 and
beyond. In earlier versions of TensorFlow you must pass your custom
regularizer to the `custom_objects` argument of methods that expect custom
regularizers to be registered as serializable.
"""
def __call__(self, x):
"""Compute a regularization penalty from an input tensor."""
return 0.
@classmethod
def from_config(cls, config):
"""Creates a regularizer from its config.
This method is the reverse of `get_config`,
capable of instantiating the same regularizer from the config
dictionary.
This method is used by Keras `model_to_estimator`, saving and
loading models to HDF5 formats, Keras model cloning, some visualization
utilities, and exporting models to and from JSON.
Arguments:
config: A Python dictionary, typically the output of get_config.
Returns:
A regularizer instance.
"""
return cls(**config)
def get_config(self):
"""Returns the config of the regularizer.
A regularizer config is a Python dictionary (serializable)
containing all configuration parameters of the regularizer.
The same regularizer can be reinstantiated later
(without any saved state) from this configuration.
This method is optional if you are just training and executing models,
exporting to and from SavedModels, or using weight checkpoints.
This method is required for Keras `model_to_estimator`, saving and
loading models to HDF5 formats, Keras model cloning, some visualization
utilities, and exporting models to and from JSON.
Returns:
Python dictionary.
"""
raise NotImplementedError(str(self) + ' does not implement get_config()')
class L1L2(Regularizer):
r"""A regularizer that applies both L1 and L2 regularization penalties.
The L1 regularization penalty is computed as:
$$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$
The L2 regularization penalty is computed as
$$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$
Attributes:
l1: Float; L1 regularization factor.
l2: Float; L2 regularization factor.
"""
def __init__(self, l1=0., l2=0.): # pylint: disable=redefined-outer-name
self.l1 = K.cast_to_floatx(l1)
self.l2 = K.cast_to_floatx(l2)
def __call__(self, x):
if not self.l1 and not self.l2:
return K.constant(0.)
regularization = 0.
if self.l1:
regularization += self.l1 * math_ops.reduce_sum(math_ops.abs(x))
if self.l2:
regularization += self.l2 * math_ops.reduce_sum(math_ops.square(x))
return regularization
def get_config(self):
return {'l1': float(self.l1), 'l2': float(self.l2)}
# Aliases.
def l1(l=0.01):
r"""Create a regularizer that applies an L1 regularization penalty.
The L1 regularization penalty is computed as:
$$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$
Arguments:
l: Float; L1 regularization factor.
Returns:
An L1 Regularizer with the given regularization factor.
"""
return L1L2(l1=l)
def l2(l=0.01):
r"""Create a regularizer that applies an L2 regularization penalty.
The L2 regularization penalty is computed as:
$$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$
Arguments:
l: Float; L2 regularization factor.
Returns:
An L2 Regularizer with the given regularization factor.
"""
return L1L2(l2=l)
def l1_l2(l1=0.01, l2=0.01): # pylint: disable=redefined-outer-name
r"""Create a regularizer that applies both L1 and L2 penalties.
The L1 regularization penalty is computed as:
$$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$
The L2 regularization penalty is computed as:
$$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$
Arguments:
l1: Float; L1 regularization factor.
l2: Float; L2 regularization factor.
Returns:
An L1L2 Regularizer with the given regularization factors.
"""
return L1L2(l1=l1, l2=l2)
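# --- Editor's worked example (not part of the original file) ---
# For x = [[1., -2.], [3., 4.]]: sum|x_i| = 10 and sum(x_i^2) = 30, so
# l1_l2(l1=0.01, l2=0.1) gives 0.01 * 10 + 0.1 * 30 = 3.1.
from tensorflow.python.framework import constant_op

reg = l1_l2(l1=0.01, l2=0.1)
x = constant_op.constant([[1., -2.], [3., 4.]])
penalty = reg(x)  # scalar tensor, approximately 3.1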
def serialize(regularizer):
return serialize_keras_object(regularizer)
def deserialize(config, custom_objects=None):
return deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='regularizer')
def get(identifier):
if identifier is None:
return None
if isinstance(identifier, dict):
return deserialize(identifier)
elif isinstance(identifier, six.string_types):
identifier = str(identifier)
# We have to special-case functions that return classes.
# TODO(omalleyt): Turn these into classes or class aliases.
special_cases = ['l1', 'l2', 'l1_l2']
if identifier in special_cases:
# Treat like a class.
return deserialize({'class_name': identifier, 'config': {}})
return deserialize(str(identifier))
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret regularizer identifier:', identifier)


@ -1,106 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "tf_utils",
srcs = ["tf_utils.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:composite_tensor",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:smart_cond",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/data/experimental/ops:cardinality",
"//tensorflow/python/eager:context",
"//tensorflow/python/frozen_keras:backend",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_library(
name = "conv_utils",
srcs = [
"conv_utils.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/frozen_keras:backend",
],
)
py_library(
name = "generic_utils",
srcs = [
"generic_utils.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
py_library(
name = "layer_utils",
srcs = [
"layer_utils.py",
],
srcs_version = "PY2AND3",
deps = [
":conv_utils",
"//tensorflow/python:util",
"//tensorflow/python/frozen_keras:backend",
"//third_party/py/numpy",
],
)
tf_py_test(
name = "generic_utils_test",
size = "small",
srcs = ["generic_utils_test.py"],
python_version = "PY3",
deps = [
":generic_utils",
"//tensorflow/python:client_testlib",
"//tensorflow/python/frozen_keras:regularizers",
"//tensorflow/python/keras",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "tf_utils_test",
size = "small",
srcs = ["tf_utils_test.py"],
python_version = "PY3",
deps = [
":tf_utils",
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
],
)
tf_py_test(
name = "conv_utils_test",
size = "small",
srcs = ["conv_utils_test.py"],
python_version = "PY3",
deps = [
":conv_utils",
"//tensorflow/python:client_testlib",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)


@ -1,482 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by convolution layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from tensorflow.python.frozen_keras import backend
def convert_data_format(data_format, ndim):
if data_format == 'channels_last':
if ndim == 3:
return 'NWC'
elif ndim == 4:
return 'NHWC'
elif ndim == 5:
return 'NDHWC'
else:
raise ValueError('Input rank not supported:', ndim)
elif data_format == 'channels_first':
if ndim == 3:
return 'NCW'
elif ndim == 4:
return 'NCHW'
elif ndim == 5:
return 'NCDHW'
else:
raise ValueError('Input rank not supported:', ndim)
else:
raise ValueError('Invalid data_format:', data_format)
def normalize_tuple(value, n, name):
"""Transforms a single integer or iterable of integers into an integer tuple.
Arguments:
value: The value to validate and convert. Could be an int, or any iterable
of ints.
n: The size of the tuple to be returned.
name: The name of the argument being validated, e.g. "strides" or
"kernel_size". This is only used to format error messages.
Returns:
A tuple of n integers.
Raises:
ValueError: If something other than an int/long or iterable thereof was
passed.
"""
if isinstance(value, int):
return (value,) * n
else:
try:
value_tuple = tuple(value)
except TypeError:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value))
if len(value_tuple) != n:
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value))
for single_value in value_tuple:
try:
int(single_value)
except (ValueError, TypeError):
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value) + ' '
'including element ' + str(single_value) + ' of type' +
' ' + str(type(single_value)))
return value_tuple
def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
"""Determines output length of a convolution given input length.
Arguments:
input_length: integer.
filter_size: integer.
padding: one of "same", "valid", "full", "causal"
stride: integer.
dilation: dilation rate, integer.
Returns:
The output length (integer).
"""
if input_length is None:
return None
assert padding in {'same', 'valid', 'full', 'causal'}
dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
if padding in ['same', 'causal']:
output_length = input_length
elif padding == 'valid':
output_length = input_length - dilated_filter_size + 1
elif padding == 'full':
output_length = input_length + dilated_filter_size - 1
return (output_length + stride - 1) // stride
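# --- Editor's worked examples (not part of the original file) ---
# 'valid': a size-3 filter on length 10 with stride 1 -> 10 - 3 + 1 = 8.
assert conv_output_length(10, 3, 'valid', 1) == 8
# 'same' preserves length before striding: ceil(10 / 2) = 5.
assert conv_output_length(10, 3, 'same', 2) == 5
# dilation=2 makes the size-3 filter effectively size 5: 10 - 5 + 1 = 6.
assert conv_output_length(10, 3, 'valid', 1, dilation=2) == 6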
def conv_input_length(output_length, filter_size, padding, stride):
"""Determines input length of a convolution given output length.
Arguments:
output_length: integer.
filter_size: integer.
padding: one of "same", "valid", "full".
stride: integer.
Returns:
The input length (integer).
"""
if output_length is None:
return None
assert padding in {'same', 'valid', 'full'}
if padding == 'same':
pad = filter_size // 2
elif padding == 'valid':
pad = 0
elif padding == 'full':
pad = filter_size - 1
return (output_length - 1) * stride - 2 * pad + filter_size
def deconv_output_length(input_length,
filter_size,
padding,
output_padding=None,
stride=0,
dilation=1):
"""Determines output length of a transposed convolution given input length.
Arguments:
input_length: Integer.
filter_size: Integer.
padding: one of `"same"`, `"valid"`, `"full"`.
output_padding: Integer, amount of padding along the output dimension. Can
be set to `None` in which case the output length is inferred.
stride: Integer.
dilation: Integer.
Returns:
The output length (integer).
"""
assert padding in {'same', 'valid', 'full'}
if input_length is None:
return None
# Get the dilated kernel size
filter_size = filter_size + (filter_size - 1) * (dilation - 1)
# Infer length if output padding is None, else compute the exact length
if output_padding is None:
if padding == 'valid':
length = input_length * stride + max(filter_size - stride, 0)
elif padding == 'full':
length = input_length * stride - (stride + filter_size - 2)
elif padding == 'same':
length = input_length * stride
else:
if padding == 'same':
pad = filter_size // 2
elif padding == 'valid':
pad = 0
elif padding == 'full':
pad = filter_size - 1
length = ((input_length - 1) * stride + filter_size - 2 * pad +
output_padding)
return length
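# --- Editor's worked examples (not part of the original file) ---
# Inferred length ('valid', stride 2, filter 3, input 4):
# 4 * 2 + max(3 - 2, 0) = 9.
assert deconv_output_length(4, 3, 'valid', stride=2) == 9
# Exact length with output_padding=1 ('same', stride 2):
# (4 - 1) * 2 + 3 - 2 * (3 // 2) + 1 = 8.
assert deconv_output_length(4, 3, 'same', output_padding=1, stride=2) == 8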
def normalize_data_format(value):
if value is None:
value = backend.image_data_format()
data_format = value.lower()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('The `data_format` argument must be one of '
'"channels_first", "channels_last". Received: ' +
str(value))
return data_format
def normalize_padding(value):
if isinstance(value, (list, tuple)):
return value
padding = value.lower()
if padding not in {'valid', 'same', 'causal'}:
raise ValueError('The `padding` argument must be a list/tuple or one of '
'"valid", "same" (or "causal", only for `Conv1D). '
'Received: ' + str(padding))
return padding
def convert_kernel(kernel):
"""Converts a Numpy kernel matrix from Theano format to TensorFlow format.
Also works reciprocally, since the transformation is its own inverse.
This is used for converting legacy Theano-saved model files.
Arguments:
kernel: Numpy array (3D, 4D or 5D).
Returns:
The converted kernel.
Raises:
ValueError: in case of invalid kernel shape or invalid data_format.
"""
kernel = np.asarray(kernel)
if not 3 <= kernel.ndim <= 5:
raise ValueError('Invalid kernel shape:', kernel.shape)
slices = [slice(None, None, -1) for _ in range(kernel.ndim)]
no_flip = (slice(None, None), slice(None, None))
slices[-2:] = no_flip
return np.copy(kernel[slices])
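# --- Editor's worked example (not part of the original file) ---
# For a 1-D conv kernel of shape (width, in_channels, out_channels), only the
# spatial axis is flipped; the two channel axes are left alone.
k = np.arange(3).reshape((3, 1, 1))         # spatial taps [0, 1, 2]
assert list(convert_kernel(k)[:, 0, 0]) == [2, 1, 0]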
def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
"""Compute a mask representing the connectivity of a convolution operation.
Assume a convolution with given parameters is applied to an input having N
spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
indicating pairs of input and output locations that are connected by a weight.
Example:
>>> input_shape = (4,)
>>> kernel_shape = (2,)
>>> strides = (1,)
>>> padding = "valid"
>>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
array([[ True, False, False],
[ True, True, False],
[False, True, True],
[False, False, True]])
where rows and columns correspond to inputs and outputs respectively.
Args:
input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
input.
kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
receptive field.
strides: tuple of size N, strides along each spatial dimension.
padding: type of padding, string `"same"` or `"valid"`.
Returns:
A boolean 2N-D `np.ndarray` of shape
`(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
is the spatial shape of the output. `True` entries in the mask represent
pairs of input-output locations that are connected by a weight.
Raises:
ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
same number of dimensions.
NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
"""
if padding not in {'same', 'valid'}:
raise NotImplementedError('Padding type %s not supported. '
'Only "valid" and "same" '
'are implemented.' % padding)
in_dims = len(input_shape)
if isinstance(kernel_shape, int):
kernel_shape = (kernel_shape,) * in_dims
if isinstance(strides, int):
strides = (strides,) * in_dims
kernel_dims = len(kernel_shape)
stride_dims = len(strides)
if kernel_dims != in_dims or stride_dims != in_dims:
raise ValueError('Number of strides, input and kernel dimensions must all '
'match. Received: %d, %d, %d.' %
(stride_dims, in_dims, kernel_dims))
output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
mask_shape = input_shape + output_shape
mask = np.zeros(mask_shape, bool)
output_axes_ticks = [range(dim) for dim in output_shape]
for output_position in itertools.product(*output_axes_ticks):
input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
output_position, strides, padding)
for input_position in itertools.product(*input_axes_ticks):
mask[input_position + output_position] = True
return mask
def conv_kernel_idxs(input_shape, kernel_shape, strides, padding, filters_in,
filters_out, data_format):
"""Yields output-input tuples of indices in a CNN layer.
The generator iterates over all `(output_idx, input_idx)` tuples, where
`output_idx` is an integer index in a flattened tensor representing a single
output image of a convolutional layer that is connected (via the layer
weights) to the respective single input image at `input_idx`
Example:
>>> input_shape = (2, 2)
>>> kernel_shape = (2, 1)
>>> strides = (1, 1)
>>> padding = "valid"
>>> filters_in = 1
>>> filters_out = 1
>>> data_format = "channels_last"
>>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
... filters_in, filters_out, data_format))
[(0, 0), (0, 2), (1, 1), (1, 3)]
Args:
input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
input.
kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
receptive field.
strides: tuple of size N, strides along each spatial dimension.
padding: type of padding, string `"same"` or `"valid"`.
filters_in: `int`, number of filters in the input to the layer.
filters_out: `int`, number of filters in the output of the layer.
data_format: string, "channels_first" or "channels_last".
Yields:
The next tuple `(output_idx, input_idx)`, where
`output_idx` is an integer index in a flattened tensor representing a single
output image of a convolutional layer that is connected (via the layer
weights) to the respective single input image at `input_idx`.
Raises:
ValueError: if `data_format` is neither
`"channels_last"` nor `"channels_first"`, or if number of strides, input,
and kernel number of dimensions do not match.
NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
"""
if padding not in ('same', 'valid'):
raise NotImplementedError('Padding type %s not supported. '
'Only "valid" and "same" '
'are implemented.' % padding)
in_dims = len(input_shape)
if isinstance(kernel_shape, int):
kernel_shape = (kernel_shape,) * in_dims
if isinstance(strides, int):
strides = (strides,) * in_dims
kernel_dims = len(kernel_shape)
stride_dims = len(strides)
if kernel_dims != in_dims or stride_dims != in_dims:
raise ValueError('Number of strides, input and kernel dimensions must all '
'match. Received: %d, %d, %d.' %
(stride_dims, in_dims, kernel_dims))
output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
output_axes_ticks = [range(dim) for dim in output_shape]
if data_format == 'channels_first':
concat_idxs = lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
elif data_format == 'channels_last':
concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (filter_idx,)
else:
raise ValueError('Data format %s not recognized. '
'`data_format` must be "channels_first" or '
'"channels_last".' % data_format)
for output_position in itertools.product(*output_axes_ticks):
input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
output_position, strides, padding)
for input_position in itertools.product(*input_axes_ticks):
for f_in in range(filters_in):
for f_out in range(filters_out):
out_idx = np.ravel_multi_index(
multi_index=concat_idxs(output_position, f_out),
dims=concat_idxs(output_shape, filters_out))
in_idx = np.ravel_multi_index(
multi_index=concat_idxs(input_position, f_in),
dims=concat_idxs(input_shape, filters_in))
yield (out_idx, in_idx)
def conv_connected_inputs(input_shape, kernel_shape, output_position, strides,
padding):
"""Return locations of the input connected to an output position.
Assume a convolution with given parameters is applied to an input having N
spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
returns N ranges specifying the input region that was convolved with the
kernel to produce the output at position
`output_position = (p_out1, ..., p_outN)`.
Example:
>>> input_shape = (4, 4)
>>> kernel_shape = (2, 1)
>>> output_position = (1, 1)
>>> strides = (1, 1)
>>> padding = "valid"
>>> conv_connected_inputs(input_shape, kernel_shape, output_position,
... strides, padding)
[range(1, 3), range(1, 2)]
Args:
input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
input.
kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
receptive field.
output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single position
in the output of the convolution.
strides: tuple of size N, strides along each spatial dimension.
padding: type of padding, string `"same"` or `"valid"`.
Returns:
N ranges `[[p_in_left1, ..., p_in_right1], ...,
[p_in_leftN, ..., p_in_rightN]]` specifying the region in the
input connected to output_position.
"""
ranges = []
ndims = len(input_shape)
for d in range(ndims):
left_shift = int(kernel_shape[d] / 2)
right_shift = kernel_shape[d] - left_shift
center = output_position[d] * strides[d]
if padding == 'valid':
center += left_shift
start = max(0, center - left_shift)
end = min(input_shape[d], center + right_shift)
ranges.append(range(start, end))
return ranges
def conv_output_shape(input_shape, kernel_shape, strides, padding):
"""Return the output shape of an N-D convolution.
Forces dimensions where input is empty (size 0) to remain empty.
Args:
input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
input.
kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
receptive field.
strides: tuple of size N, strides along each spatial dimension.
padding: type of padding, string `"same"` or `"valid"`.
Returns:
tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
"""
dims = range(len(kernel_shape))
output_shape = [
conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d])
for d in dims
]
output_shape = tuple(
[0 if input_shape[d] == 0 else output_shape[d] for d in dims])
return output_shape
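# A short sketch of `conv_output_shape`, including the forced-empty behavior
# noted in its docstring:
#
#   >>> conv_output_shape((4, 4), (2, 2), (1, 1), 'valid')
#   (3, 3)
#   >>> conv_output_shape((0, 4), (2, 2), (1, 1), 'valid')
#   (0, 3)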


@@ -1,340 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for conv_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.frozen_keras.utils import conv_utils
from tensorflow.python.platform import test
def _get_const_output_shape(input_shape, dim):
  """Clamps each dimension of `input_shape` to be at most `dim`."""
  return tuple([min(d, dim) for d in input_shape])
input_shapes = [
(0,),
(0, 0),
(1,),
(2,),
(3,),
(1, 0),
(0, 3),
(1, 1),
(1, 2),
(3, 1),
(2, 2),
(3, 3),
(1, 0, 1),
(5, 2, 3),
(3, 5, 6, 7, 0),
(3, 2, 2, 4, 4),
(1, 2, 3, 4, 7, 2),
]
class TestBasicConvUtilsTest(test.TestCase):
def test_convert_data_format(self):
self.assertEqual('NCDHW', conv_utils.convert_data_format(
'channels_first', 5))
self.assertEqual('NCHW', conv_utils.convert_data_format(
'channels_first', 4))
self.assertEqual('NCW', conv_utils.convert_data_format('channels_first', 3))
self.assertEqual('NHWC', conv_utils.convert_data_format('channels_last', 4))
self.assertEqual('NWC', conv_utils.convert_data_format('channels_last', 3))
self.assertEqual('NDHWC', conv_utils.convert_data_format(
'channels_last', 5))
with self.assertRaises(ValueError):
conv_utils.convert_data_format('invalid', 2)
def test_normalize_tuple(self):
self.assertEqual((2, 2, 2),
conv_utils.normalize_tuple(2, n=3, name='strides'))
self.assertEqual((2, 1, 2),
conv_utils.normalize_tuple((2, 1, 2), n=3, name='strides'))
with self.assertRaises(ValueError):
conv_utils.normalize_tuple((2, 1), n=3, name='strides')
with self.assertRaises(ValueError):
conv_utils.normalize_tuple(None, n=3, name='strides')
def test_normalize_data_format(self):
self.assertEqual('channels_last',
conv_utils.normalize_data_format('Channels_Last'))
self.assertEqual('channels_first',
conv_utils.normalize_data_format('CHANNELS_FIRST'))
with self.assertRaises(ValueError):
conv_utils.normalize_data_format('invalid')
def test_normalize_padding(self):
self.assertEqual('same', conv_utils.normalize_padding('SAME'))
self.assertEqual('valid', conv_utils.normalize_padding('VALID'))
with self.assertRaises(ValueError):
conv_utils.normalize_padding('invalid')
def test_conv_output_length(self):
self.assertEqual(4, conv_utils.conv_output_length(4, 2, 'same', 1, 1))
self.assertEqual(2, conv_utils.conv_output_length(4, 2, 'same', 2, 1))
self.assertEqual(3, conv_utils.conv_output_length(4, 2, 'valid', 1, 1))
self.assertEqual(2, conv_utils.conv_output_length(4, 2, 'valid', 2, 1))
self.assertEqual(5, conv_utils.conv_output_length(4, 2, 'full', 1, 1))
self.assertEqual(3, conv_utils.conv_output_length(4, 2, 'full', 2, 1))
self.assertEqual(2, conv_utils.conv_output_length(5, 2, 'valid', 2, 2))
def test_conv_input_length(self):
self.assertEqual(3, conv_utils.conv_input_length(4, 2, 'same', 1))
self.assertEqual(2, conv_utils.conv_input_length(2, 2, 'same', 2))
self.assertEqual(4, conv_utils.conv_input_length(3, 2, 'valid', 1))
self.assertEqual(4, conv_utils.conv_input_length(2, 2, 'valid', 2))
self.assertEqual(3, conv_utils.conv_input_length(4, 2, 'full', 1))
self.assertEqual(4, conv_utils.conv_input_length(3, 2, 'full', 2))
def test_deconv_output_length(self):
self.assertEqual(4, conv_utils.deconv_output_length(4, 2, 'same', stride=1))
self.assertEqual(8, conv_utils.deconv_output_length(4, 2, 'same', stride=2))
self.assertEqual(5, conv_utils.deconv_output_length(
4, 2, 'valid', stride=1))
self.assertEqual(8, conv_utils.deconv_output_length(
4, 2, 'valid', stride=2))
self.assertEqual(3, conv_utils.deconv_output_length(4, 2, 'full', stride=1))
self.assertEqual(6, conv_utils.deconv_output_length(4, 2, 'full', stride=2))
self.assertEqual(
5,
conv_utils.deconv_output_length(
4, 2, 'same', output_padding=2, stride=1))
self.assertEqual(
7,
conv_utils.deconv_output_length(
4, 2, 'same', output_padding=1, stride=2))
self.assertEqual(
7,
conv_utils.deconv_output_length(
4, 2, 'valid', output_padding=2, stride=1))
self.assertEqual(
9,
conv_utils.deconv_output_length(
4, 2, 'valid', output_padding=1, stride=2))
self.assertEqual(
5,
conv_utils.deconv_output_length(
4, 2, 'full', output_padding=2, stride=1))
self.assertEqual(
7,
conv_utils.deconv_output_length(
4, 2, 'full', output_padding=1, stride=2))
self.assertEqual(
5,
conv_utils.deconv_output_length(
4, 2, 'same', output_padding=1, stride=1, dilation=2))
self.assertEqual(
12,
conv_utils.deconv_output_length(
4, 2, 'valid', output_padding=2, stride=2, dilation=3))
self.assertEqual(
6,
conv_utils.deconv_output_length(
4, 2, 'full', output_padding=2, stride=2, dilation=3))
@parameterized.parameters(input_shapes)
class TestConvUtils(test.TestCase, parameterized.TestCase):
def test_conv_kernel_mask_fc(self, *input_shape):
padding = 'valid'
kernel_shape = input_shape
ndims = len(input_shape)
strides = (1,) * ndims
output_shape = _get_const_output_shape(input_shape, dim=1)
    mask = np.ones(input_shape + output_shape, bool)
self.assertAllEqual(
mask,
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
padding
)
)
def test_conv_kernel_mask_diag(self, *input_shape):
ndims = len(input_shape)
kernel_shape = (1,) * ndims
strides = (1,) * ndims
for padding in ['valid', 'same']:
      mask = np.identity(int(np.prod(input_shape)), bool)
mask = np.reshape(mask, input_shape * 2)
self.assertAllEqual(
mask,
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
padding
)
)
def test_conv_kernel_mask_full_stride(self, *input_shape):
padding = 'valid'
ndims = len(input_shape)
kernel_shape = (1,) * ndims
strides = tuple([max(d, 1) for d in input_shape])
output_shape = _get_const_output_shape(input_shape, dim=1)
    mask = np.zeros(input_shape + output_shape, bool)
if all(d > 0 for d in mask.shape):
mask[(0,) * len(output_shape)] = True
self.assertAllEqual(
mask,
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
padding
)
)
def test_conv_kernel_mask_almost_full_stride(self, *input_shape):
padding = 'valid'
ndims = len(input_shape)
kernel_shape = (1,) * ndims
strides = tuple([max(d - 1, 1) for d in input_shape])
output_shape = _get_const_output_shape(input_shape, dim=2)
    mask = np.zeros(input_shape + output_shape, bool)
if all(d > 0 for d in mask.shape):
for in_position in itertools.product(*[[0, d - 1] for d in input_shape]):
out_position = tuple([min(p, 1) for p in in_position])
mask[in_position + out_position] = True
self.assertAllEqual(
mask,
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
padding
)
)
def test_conv_kernel_mask_rect_kernel(self, *input_shape):
padding = 'valid'
ndims = len(input_shape)
strides = (1,) * ndims
for d in range(ndims):
kernel_shape = [1] * ndims
kernel_shape[d] = input_shape[d]
output_shape = list(input_shape)
output_shape[d] = min(1, input_shape[d])
      mask = np.identity(int(np.prod(input_shape)), bool)
mask = np.reshape(mask, input_shape * 2)
for p in itertools.product(*[range(input_shape[dim])
for dim in range(ndims)]):
p = list(p)
p[d] = slice(None)
        mask[tuple(p) * 2] = True
mask = np.take(mask, range(0, min(1, input_shape[d])), ndims + d)
self.assertAllEqual(
mask,
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
padding
)
)
def test_conv_kernel_mask_wrong_padding(self, *input_shape):
ndims = len(input_shape)
kernel_shape = (1,) * ndims
strides = (1,) * ndims
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
'valid'
)
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
'same'
)
self.assertRaises(NotImplementedError,
conv_utils.conv_kernel_mask,
input_shape, kernel_shape, strides, 'full')
def test_conv_kernel_mask_wrong_dims(self, *input_shape):
kernel_shape = 1
strides = 1
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
'valid'
)
ndims = len(input_shape)
kernel_shape = (2,) * (ndims + 1)
self.assertRaises(ValueError,
conv_utils.conv_kernel_mask,
input_shape, kernel_shape, strides, 'same')
strides = (1,) * ndims
self.assertRaises(ValueError,
conv_utils.conv_kernel_mask,
input_shape, kernel_shape, strides, 'valid')
kernel_shape = (1,) * ndims
strides = (2,) * (ndims - 1)
self.assertRaises(ValueError,
conv_utils.conv_kernel_mask,
input_shape, kernel_shape, strides, 'valid')
strides = (2,) * ndims
conv_utils.conv_kernel_mask(
input_shape,
kernel_shape,
strides,
'valid'
)
if __name__ == '__main__':
test.main()


@@ -1,612 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python utilities required by Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import binascii
import codecs
import marshal
import os
import re
import types as python_types
import numpy as np
import six
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
_GLOBAL_CUSTOM_OBJECTS = {}
_GLOBAL_CUSTOM_NAMES = {}
# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required.
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
class CustomObjectScope(object):
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
  Code within a `with` statement will be able to access custom objects
  by name. Changes to global custom objects persist within the enclosing
  `with` statement; when it ends, global custom objects are reverted to
  their state at the beginning of the statement.
Example:
Consider a custom object `MyObject` (e.g. a class):
```python
with CustomObjectScope({'MyObject':MyObject}):
layer = Dense(..., kernel_regularizer='MyObject')
# save, load, etc. will recognize custom object by name
```
"""
def __init__(self, *args):
self.custom_objects = args
self.backup = None
def __enter__(self):
self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
for objects in self.custom_objects:
_GLOBAL_CUSTOM_OBJECTS.update(objects)
return self
def __exit__(self, *args, **kwargs):
_GLOBAL_CUSTOM_OBJECTS.clear()
_GLOBAL_CUSTOM_OBJECTS.update(self.backup)
def custom_object_scope(*args):
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
Convenience wrapper for `CustomObjectScope`.
  Code within a `with` statement will be able to access custom objects
  by name. Changes to global custom objects persist within the enclosing
  `with` statement; when it ends, global custom objects are reverted to
  their state at the beginning of the statement.
  Example:
  Consider a custom object `MyObject` (e.g. a class):
```python
with custom_object_scope({'MyObject':MyObject}):
layer = Dense(..., kernel_regularizer='MyObject')
# save, load, etc. will recognize custom object by name
```
Arguments:
*args: Variable length list of dictionaries of name, class pairs to add to
custom objects.
Returns:
Object of type `CustomObjectScope`.
"""
return CustomObjectScope(*args)
def get_custom_objects():
"""Retrieves a live reference to the global dictionary of custom objects.
Updating and clearing custom objects using `custom_object_scope`
is preferred, but `get_custom_objects` can
be used to directly access `_GLOBAL_CUSTOM_OBJECTS`.
Example:
```python
get_custom_objects().clear()
get_custom_objects()['MyObject'] = MyObject
```
Returns:
Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
"""
return _GLOBAL_CUSTOM_OBJECTS
def serialize_keras_class_and_config(cls_name, cls_config):
"""Returns the serialization of the class with the given config."""
return {'class_name': cls_name, 'config': cls_config}
def register_keras_serializable(package='Custom', name=None):
"""Registers an object with the Keras serialization framework.
This decorator injects the decorated class or function into the Keras custom
object dictionary, so that it can be serialized and deserialized without
needing an entry in the user-provided custom object dict. It also injects a
function that Keras will call to get the object's serializable string key.
Note that to be serialized and deserialized, classes must implement the
`get_config()` method. Functions do not have this requirement.
  The object will be registered under the key 'package>name', where `name`
  defaults to the object name if not passed.
Arguments:
package: The package that this class belongs to.
name: The name to serialize this class under in this package. If None, the
class's name will be used.
Returns:
A decorator that registers the decorated class with the passed names.
"""
def decorator(arg):
"""Registers a class with the Keras serialization framework."""
class_name = name if name is not None else arg.__name__
registered_name = package + '>' + class_name
if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'):
raise ValueError(
'Cannot register a class that does not have a get_config() method.')
if registered_name in _GLOBAL_CUSTOM_OBJECTS:
raise ValueError(
'%s has already been registered to %s' %
(registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name]))
if arg in _GLOBAL_CUSTOM_NAMES:
raise ValueError('%s has already been registered to %s' %
(arg, _GLOBAL_CUSTOM_NAMES[arg]))
_GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
_GLOBAL_CUSTOM_NAMES[arg] = registered_name
return arg
return decorator
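# A minimal registration sketch (class and package names are illustrative):
#
#   >>> @register_keras_serializable(package='MyPackage')
#   ... class Scaler(object):
#   ...   def __init__(self, scale):
#   ...     self.scale = scale
#   ...   def get_config(self):
#   ...     return {'scale': self.scale}
#   >>> get_registered_name(Scaler)
#   'MyPackage>Scaler'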
def get_registered_name(obj):
"""Returns the name registered to an object within the Keras framework.
This function is part of the Keras serialization and deserialization
framework. It maps objects to the string names associated with those objects
for serialization/deserialization.
Args:
obj: The object to look up.
Returns:
The name associated with the object, or the default Python name if the
object is not registered.
"""
if obj in _GLOBAL_CUSTOM_NAMES:
return _GLOBAL_CUSTOM_NAMES[obj]
else:
return obj.__name__
@tf_contextlib.contextmanager
def skip_failed_serialization():
global _SKIP_FAILED_SERIALIZATION
prev = _SKIP_FAILED_SERIALIZATION
try:
_SKIP_FAILED_SERIALIZATION = True
yield
finally:
_SKIP_FAILED_SERIALIZATION = prev
def get_registered_object(name, custom_objects=None, module_objects=None):
"""Returns the class associated with `name` if it is registered with Keras.
This function is part of the Keras serialization and deserialization
framework. It maps strings to the objects associated with them for
serialization/deserialization.
Example:
  ```python
def from_config(cls, config, custom_objects=None):
if 'my_custom_object_name' in config:
config['hidden_cls'] = tf.keras.utils.get_registered_object(
config['my_custom_object_name'], custom_objects=custom_objects)
```
Args:
name: The name to look up.
custom_objects: A dictionary of custom objects to look the name up in.
Generally, custom_objects is provided by the user.
module_objects: A dictionary of custom objects to look the name up in.
Generally, module_objects is provided by midlevel library implementers.
Returns:
An instantiable class associated with 'name', or None if no such class
exists.
"""
if name in _GLOBAL_CUSTOM_OBJECTS:
return _GLOBAL_CUSTOM_OBJECTS[name]
elif custom_objects and name in custom_objects:
return custom_objects[name]
elif module_objects and name in module_objects:
return module_objects[name]
return None
def serialize_keras_object(instance):
"""Serialize Keras object into JSON."""
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
if hasattr(instance, 'get_config'):
name = get_registered_name(instance.__class__)
try:
config = instance.get_config()
except NotImplementedError as e:
if _SKIP_FAILED_SERIALIZATION:
return serialize_keras_class_and_config(
name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
raise e
serialization_config = {}
for key, item in config.items():
if isinstance(item, six.string_types):
serialization_config[key] = item
continue
# Any object of a different type needs to be converted to string or dict
# for serialization (e.g. custom functions, custom classes)
try:
serialized_item = serialize_keras_object(item)
if isinstance(serialized_item, dict) and not isinstance(item, dict):
serialized_item['__passive_serialization__'] = True
serialization_config[key] = serialized_item
except ValueError:
serialization_config[key] = item
name = get_registered_name(instance.__class__)
return serialize_keras_class_and_config(name, serialization_config)
if hasattr(instance, '__name__'):
return get_registered_name(instance)
raise ValueError('Cannot serialize', instance)
def get_custom_objects_by_name(item, custom_objects=None):
"""Returns the item if it is in either local or global custom objects."""
if item in _GLOBAL_CUSTOM_OBJECTS:
return _GLOBAL_CUSTOM_OBJECTS[item]
elif custom_objects and item in custom_objects:
return custom_objects[item]
return None
def class_and_config_for_serialized_keras_object(
config,
module_objects=None,
custom_objects=None,
printable_module_name='object'):
"""Returns the class name and config for a serialized keras object."""
if (not isinstance(config, dict) or 'class_name' not in config or
'config' not in config):
raise ValueError('Improper config format: ' + str(config))
class_name = config['class_name']
cls = get_registered_object(class_name, custom_objects, module_objects)
if cls is None:
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
cls_config = config['config']
deserialized_objects = {}
for key, item in cls_config.items():
if isinstance(item, dict) and '__passive_serialization__' in item:
deserialized_objects[key] = deserialize_keras_object(
item,
module_objects=module_objects,
custom_objects=custom_objects,
printable_module_name='config_item')
# TODO(momernick): Should this also have 'module_objects'?
elif (isinstance(item, six.string_types) and
tf_inspect.isfunction(get_registered_object(item, custom_objects))):
# Handle custom functions here. When saving functions, we only save the
# function's name as a string. If we find a matching string in the custom
# objects during deserialization, we convert the string back to the
# original function.
# Note that a potential issue is that a string field could have a naming
# conflict with a custom function name, but this should be a rare case.
# This issue does not occur if a string field has a naming conflict with
# a custom object, since the config of an object will always be a dict.
deserialized_objects[key] = get_registered_object(item, custom_objects)
for key, item in deserialized_objects.items():
cls_config[key] = deserialized_objects[key]
return (cls, cls_config)
def deserialize_keras_object(identifier,
module_objects=None,
custom_objects=None,
                             printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object."""
if identifier is None:
return None
if isinstance(identifier, dict):
# In this case we are dealing with a Keras config dictionary.
config = identifier
(cls, cls_config) = class_and_config_for_serialized_keras_object(
config, module_objects, custom_objects, printable_module_name)
if hasattr(cls, 'from_config'):
arg_spec = tf_inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}
if 'custom_objects' in arg_spec.args:
return cls.from_config(
cls_config,
custom_objects=dict(
list(_GLOBAL_CUSTOM_OBJECTS.items()) +
list(custom_objects.items())))
with CustomObjectScope(custom_objects):
return cls.from_config(cls_config)
else:
# Then `cls` may be a function returning a class.
# in this case by convention `config` holds
# the kwargs of the function.
custom_objects = custom_objects or {}
with CustomObjectScope(custom_objects):
return cls(**cls_config)
elif isinstance(identifier, six.string_types):
object_name = identifier
if custom_objects and object_name in custom_objects:
obj = custom_objects.get(object_name)
elif object_name in _GLOBAL_CUSTOM_OBJECTS:
obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
else:
obj = module_objects.get(object_name)
if obj is None:
      raise ValueError('Unknown ' + printable_module_name + ': ' + object_name)
# Classes passed by name are instantiated with no args, functions are
# returned as-is.
if tf_inspect.isclass(obj):
return obj()
return obj
elif tf_inspect.isfunction(identifier):
# If a function has already been deserialized, return as is.
return identifier
else:
raise ValueError('Could not interpret serialized %s: %s' %
(printable_module_name, identifier))
def func_dump(func):
"""Serializes a user defined function.
Arguments:
func: the function to serialize.
Returns:
A tuple `(code, defaults, closure)`.
"""
if os.name == 'nt':
raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/')
code = codecs.encode(raw_code, 'base64').decode('ascii')
else:
raw_code = marshal.dumps(func.__code__)
code = codecs.encode(raw_code, 'base64').decode('ascii')
defaults = func.__defaults__
if func.__closure__:
closure = tuple(c.cell_contents for c in func.__closure__)
else:
closure = None
return code, defaults, closure
def func_load(code, defaults=None, closure=None, globs=None):
"""Deserializes a user defined function.
Arguments:
code: bytecode of the function.
defaults: defaults of the function.
closure: closure of the function.
globs: dictionary of global objects.
Returns:
A function object.
"""
if isinstance(code, (tuple, list)): # unpack previous dump
code, defaults, closure = code
if isinstance(defaults, list):
defaults = tuple(defaults)
def ensure_value_to_cell(value):
"""Ensures that a value is converted to a python cell object.
Arguments:
        value: Any value that needs to be cast to the cell type
Returns:
A value wrapped as a cell object (see function "func_load")
"""
def dummy_fn():
# pylint: disable=pointless-statement
value # just access it so it gets captured in .__closure__
cell_value = dummy_fn.__closure__[0]
if not isinstance(value, type(cell_value)):
return cell_value
return value
if closure is not None:
closure = tuple(ensure_value_to_cell(_) for _ in closure)
try:
raw_code = codecs.decode(code.encode('ascii'), 'base64')
except (UnicodeEncodeError, binascii.Error):
raw_code = code.encode('raw_unicode_escape')
code = marshal.loads(raw_code)
if globs is None:
globs = globals()
return python_types.FunctionType(
code, globs, name=code.co_name, argdefs=defaults, closure=closure)
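# Round-trip sketch: `func_load` undoes `func_dump` (names are illustrative;
# this assumes a POSIX platform and a matching Python version, since marshal
# bytecode is version-specific and the Windows branch rewrites backslashes):
#
#   >>> def add_one(x, y=1):
#   ...   return x + y
#   >>> code, defaults, closure = func_dump(add_one)
#   >>> restored = func_load(code, defaults, closure)
#   >>> restored(2)
#   3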
def has_arg(fn, name, accept_all=False):
"""Checks if a callable accepts a given keyword argument.
Arguments:
fn: Callable to inspect.
name: Check if `fn` can be called with `name` as a keyword argument.
accept_all: What to return if there is no parameter called `name` but the
function accepts a `**kwargs` argument.
Returns:
bool, whether `fn` accepts a `name` keyword argument.
"""
arg_spec = tf_inspect.getfullargspec(fn)
if accept_all and arg_spec.varkw is not None:
return True
return name in arg_spec.args
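# For example (a sketch; `f` is illustrative):
#
#   >>> def f(a, b=2, **kwargs):
#   ...   return a + b
#   >>> has_arg(f, 'b')
#   True
#   >>> has_arg(f, 'c'), has_arg(f, 'c', accept_all=True)
#   (False, True)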
def make_batches(size, batch_size):
"""Returns a list of batch indices (tuples of indices).
Arguments:
size: Integer, total size of the data to slice into batches.
batch_size: Integer, batch size.
Returns:
A list of tuples of array indices.
"""
num_batches = int(np.ceil(size / float(batch_size)))
return [(i * batch_size, min(size, (i + 1) * batch_size))
for i in range(0, num_batches)]
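# For example, a size-5 dataset in batches of 2 leaves a short final batch:
#
#   >>> make_batches(5, 2)
#   [(0, 2), (2, 4), (4, 5)]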
def slice_arrays(arrays, start=None, stop=None):
"""Slice an array or list of arrays.
This takes an array-like, or a list of
array-likes, and outputs:
- arrays[start:stop] if `arrays` is an array-like
- [x[start:stop] for x in arrays] if `arrays` is a list
Can also work on list/array of indices: `slice_arrays(x, indices)`
Arguments:
arrays: Single array or list of arrays.
start: can be an integer index (start index) or a list/array of indices
stop: integer (stop index); should be None if `start` was a list.
Returns:
A slice of the array(s).
Raises:
ValueError: If the value of start is a list and stop is not None.
"""
if arrays is None:
return [None]
if isinstance(start, list) and stop is not None:
raise ValueError('The stop argument has to be None if the value of start '
'is a list.')
elif isinstance(arrays, list):
if hasattr(start, '__len__'):
# hdf5 datasets only support list objects as indices
if hasattr(start, 'shape'):
start = start.tolist()
return [None if x is None else x[start] for x in arrays]
return [
None if x is None else
None if not hasattr(x, '__getitem__') else x[start:stop] for x in arrays
]
else:
if hasattr(start, '__len__'):
if hasattr(start, 'shape'):
start = start.tolist()
return arrays[start]
if hasattr(start, '__getitem__'):
return arrays[start:stop]
return [None]
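# A short sketch with a list of arrays (`np` is imported above):
#
#   >>> x = [np.arange(4), 2 * np.arange(4)]
#   >>> [a.tolist() for a in slice_arrays(x, start=1, stop=3)]
#   [[1, 2], [2, 4]]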
def to_list(x):
"""Normalizes a list/tensor into a list.
If a tensor is passed, we return
a list of size 1 containing the tensor.
Arguments:
x: target object to be normalized.
Returns:
A list.
"""
if isinstance(x, list):
return x
return [x]
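# For example:
#
#   >>> to_list([1, 2]), to_list(3)
#   ([1, 2], [3])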
def to_snake_case(name):
  """Converts a CamelCase string to snake_case, e.g. 'Conv2D' -> 'conv2d'."""
intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
# If the class is private the name starts with "_" which is not secure
# for creating scopes. We prefix the name with "private" in this case.
if insecure[0] != '_':
return insecure
return 'private' + insecure
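# For example:
#
#   >>> to_snake_case('Conv2DTranspose')
#   'conv2d_transpose'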
def is_all_none(structure):
iterable = nest.flatten(structure)
  # We cannot use Python's `any` because the iterable may contain Tensors.
for element in iterable:
if element is not None:
return False
return True
def check_for_unexpected_keys(name, input_dict, expected_values):
unknown = set(input_dict.keys()).difference(expected_values)
if unknown:
raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
'following keys: {}'.format(name, list(unknown),
expected_values))
def validate_kwargs(kwargs,
allowed_kwargs,
error_message='Keyword argument not understood:'):
"""Checks that all keyword arguments are in the set of allowed keys."""
for kwarg in kwargs:
if kwarg not in allowed_kwargs:
raise TypeError(error_message, kwarg)
def validate_config(config):
"""Determines whether config appears to be a valid layer config."""
return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config
def default(method):
"""Decorates a method to detect overrides in subclasses."""
method._is_default = True # pylint: disable=protected-access
return method
def is_default(method):
"""Check if a method is decorated with the `default` wrapper."""
return getattr(method, '_is_default', False)
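# Sketch of the `default` / `is_default` pair (classes are illustrative):
#
#   >>> class Base(object):
#   ...   @default
#   ...   def get_config(self):
#   ...     return {}
#   >>> class Sub(Base):
#   ...   def get_config(self):
#   ...     return {'units': 1}
#   >>> is_default(Base.get_config), is_default(Sub.get_config)
#   (True, False)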


@@ -1,321 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras generic Python utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import keras
from tensorflow.python.frozen_keras import regularizers
from tensorflow.python.frozen_keras.utils import generic_utils
from tensorflow.python.platform import test
class HasArgTest(test.TestCase):
def test_has_arg(self):
def f_x(x):
return x
def f_x_args(x, *args):
_ = args
return x
def f_x_kwargs(x, **kwargs):
_ = kwargs
return x
self.assertTrue(generic_utils.has_arg(
f_x, 'x', accept_all=False))
self.assertFalse(generic_utils.has_arg(
f_x, 'y', accept_all=False))
self.assertTrue(generic_utils.has_arg(
f_x_args, 'x', accept_all=False))
self.assertFalse(generic_utils.has_arg(
f_x_args, 'y', accept_all=False))
self.assertTrue(generic_utils.has_arg(
f_x_kwargs, 'x', accept_all=False))
self.assertFalse(generic_utils.has_arg(
f_x_kwargs, 'y', accept_all=False))
self.assertTrue(generic_utils.has_arg(
f_x_kwargs, 'y', accept_all=True))
class TestCustomObjectScope(test.TestCase):
def test_custom_object_scope(self):
def custom_fn():
pass
class CustomClass(object):
pass
with generic_utils.custom_object_scope(
{'CustomClass': CustomClass, 'custom_fn': custom_fn}):
      # Disable activation test since it's not under the frozen_keras package.
# act = keras.activations.get('custom_fn')
# self.assertEqual(act, custom_fn)
cl = regularizers.get('CustomClass')
self.assertEqual(cl.__class__, CustomClass)
class SerializeKerasObjectTest(test.TestCase):
def test_serialize_none(self):
serialized = generic_utils.serialize_keras_object(None)
self.assertEqual(serialized, None)
deserialized = generic_utils.deserialize_keras_object(
serialized)
self.assertEqual(deserialized, None)
def test_serialize_custom_class_with_default_name(self):
@generic_utils.register_keras_serializable()
class TestClass(object):
def __init__(self, value):
self._value = value
def get_config(self):
return {'value': self._value}
serialized_name = 'Custom>TestClass'
inst = TestClass(value=10)
class_name = generic_utils._GLOBAL_CUSTOM_NAMES[TestClass]
self.assertEqual(serialized_name, class_name)
config = generic_utils.serialize_keras_object(inst)
self.assertEqual(class_name, config['class_name'])
new_inst = generic_utils.deserialize_keras_object(config)
self.assertIsNot(inst, new_inst)
self.assertIsInstance(new_inst, TestClass)
self.assertEqual(10, new_inst._value)
# Make sure registering a new class with same name will fail.
with self.assertRaisesRegex(ValueError, '.*has already been registered.*'):
@generic_utils.register_keras_serializable() # pylint: disable=function-redefined
class TestClass(object):
def __init__(self, value):
self._value = value
def get_config(self):
return {'value': self._value}
def test_serialize_custom_class_with_custom_name(self):
@generic_utils.register_keras_serializable(
'TestPackage', 'CustomName')
class OtherTestClass(object):
def __init__(self, val):
self._val = val
def get_config(self):
return {'val': self._val}
serialized_name = 'TestPackage>CustomName'
inst = OtherTestClass(val=5)
class_name = generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
self.assertEqual(serialized_name, class_name)
fn_class_name = generic_utils.get_registered_name(
OtherTestClass)
self.assertEqual(fn_class_name, class_name)
cls = generic_utils.get_registered_object(fn_class_name)
self.assertEqual(OtherTestClass, cls)
config = generic_utils.serialize_keras_object(inst)
self.assertEqual(class_name, config['class_name'])
new_inst = generic_utils.deserialize_keras_object(config)
self.assertIsNot(inst, new_inst)
self.assertIsInstance(new_inst, OtherTestClass)
self.assertEqual(5, new_inst._val)
def test_serialize_custom_function(self):
@generic_utils.register_keras_serializable()
def my_fn():
return 42
serialized_name = 'Custom>my_fn'
class_name = generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
self.assertEqual(serialized_name, class_name)
fn_class_name = generic_utils.get_registered_name(my_fn)
self.assertEqual(fn_class_name, class_name)
config = generic_utils.serialize_keras_object(my_fn)
self.assertEqual(class_name, config)
fn = generic_utils.deserialize_keras_object(config)
self.assertEqual(42, fn())
fn_2 = generic_utils.get_registered_object(fn_class_name)
self.assertEqual(42, fn_2())
def test_serialize_custom_class_without_get_config_fails(self):
with self.assertRaisesRegex(
ValueError, 'Cannot register a class that does '
'not have a get_config.*'):
@generic_utils.register_keras_serializable( # pylint: disable=unused-variable
'TestPackage', 'TestClass')
class TestClass(object):
def __init__(self, value):
self._value = value
def test_serializable_object(self):
class SerializableInt(int):
"""A serializable object to pass out of a test layer's config."""
def __new__(cls, value):
return int.__new__(cls, value)
def get_config(self):
return {'value': int(self)}
@classmethod
def from_config(cls, config):
return cls(**config)
layer = keras.layers.Dense(
SerializableInt(3),
activation='relu',
kernel_initializer='ones',
bias_regularizer='l2')
config = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
config, custom_objects={'SerializableInt': SerializableInt})
self.assertEqual(new_layer.activation, keras.activations.relu)
self.assertEqual(new_layer.bias_regularizer.__class__,
keras.regularizers.L1L2)
self.assertEqual(new_layer.units.__class__, SerializableInt)
self.assertEqual(new_layer.units, 3)
def test_nested_serializable_object(self):
class SerializableInt(int):
"""A serializable object to pass out of a test layer's config."""
def __new__(cls, value):
return int.__new__(cls, value)
def get_config(self):
return {'value': int(self)}
@classmethod
def from_config(cls, config):
return cls(**config)
class SerializableNestedInt(int):
"""A serializable object containing another serializable object."""
def __new__(cls, value, int_obj):
obj = int.__new__(cls, value)
obj.int_obj = int_obj
return obj
def get_config(self):
return {'value': int(self), 'int_obj': self.int_obj}
@classmethod
def from_config(cls, config):
return cls(**config)
nested_int = SerializableInt(4)
layer = keras.layers.Dense(
SerializableNestedInt(3, nested_int),
name='SerializableNestedInt',
activation='relu',
kernel_initializer='ones',
bias_regularizer='l2')
config = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
config,
custom_objects={
'SerializableInt': SerializableInt,
'SerializableNestedInt': SerializableNestedInt
})
    # Make sure the string field doesn't get converted to a custom object,
    # even though they have the same value.
self.assertEqual(new_layer.name, 'SerializableNestedInt')
self.assertEqual(new_layer.activation, keras.activations.relu)
self.assertEqual(new_layer.bias_regularizer.__class__,
keras.regularizers.L1L2)
self.assertEqual(new_layer.units.__class__, SerializableNestedInt)
self.assertEqual(new_layer.units, 3)
self.assertEqual(new_layer.units.int_obj.__class__, SerializableInt)
self.assertEqual(new_layer.units.int_obj, 4)
def test_nested_serializable_fn(self):
def serializable_fn(x):
"""A serializable function to pass out of a test layer's config."""
return x
class SerializableNestedInt(int):
"""A serializable object containing a serializable function."""
def __new__(cls, value, fn):
obj = int.__new__(cls, value)
obj.fn = fn
return obj
def get_config(self):
return {'value': int(self), 'fn': self.fn}
@classmethod
def from_config(cls, config):
return cls(**config)
layer = keras.layers.Dense(
SerializableNestedInt(3, serializable_fn),
activation='relu',
kernel_initializer='ones',
bias_regularizer='l2')
config = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(
config,
custom_objects={
'serializable_fn': serializable_fn,
'SerializableNestedInt': SerializableNestedInt
})
self.assertEqual(new_layer.activation, keras.activations.relu)
self.assertIsInstance(new_layer.bias_regularizer, keras.regularizers.L1L2)
self.assertIsInstance(new_layer.units, SerializableNestedInt)
self.assertEqual(new_layer.units, 3)
self.assertIs(new_layer.units.fn, serializable_fn)
class SliceArraysTest(test.TestCase):
def test_slice_arrays(self):
input_a = list([1, 2, 3])
self.assertEqual(
generic_utils.slice_arrays(input_a, start=0),
[None, None, None])
self.assertEqual(
generic_utils.slice_arrays(input_a, stop=3),
[None, None, None])
self.assertEqual(
generic_utils.slice_arrays(input_a, start=0, stop=1),
[None, None, None])
if __name__ == '__main__':
test.main()


@@ -1,403 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.frozen_keras.utils.conv_utils import convert_kernel
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
def get_source_inputs(tensor, layer=None, node_index=None):
"""Returns the list of input tensors necessary to compute `tensor`.
Output will always be a list of tensors
(potentially with 1 element).
Arguments:
tensor: The tensor to start from.
layer: Origin layer of the tensor. Will be
determined via tensor._keras_history if not provided.
node_index: Origin node index of the tensor.
Returns:
List of input tensors.
"""
if not hasattr(tensor, '_keras_history'):
return tensor
if layer is None or node_index:
layer, node_index, _ = tensor._keras_history
if not layer._inbound_nodes:
return [tensor]
else:
node = layer._inbound_nodes[node_index]
if not node.inbound_layers:
# Reached an Input layer, stop recursion.
return nest.flatten(node.input_tensors)
else:
source_tensors = []
for layer, node_index, _, tensor in node.iterate_inbound():
previous_sources = get_source_inputs(tensor, layer, node_index)
# Avoid input redundancy.
for x in previous_sources:
if all(x is not t for t in source_tensors):
source_tensors.append(x)
return source_tensors
def validate_string_arg(input_data,
allowable_strings,
layer_name,
arg_name,
allow_none=False,
allow_callables=False):
"""Validates the correctness of a string-based arg."""
if allow_none and input_data is None:
return
elif allow_callables and callable(input_data):
return
elif isinstance(input_data,
six.string_types) and input_data in allowable_strings:
return
else:
allowed_args = '`None`, ' if allow_none else ''
allowed_args += 'a `Callable`, ' if allow_callables else ''
allowed_args += 'or one of the following values: %s' % (allowable_strings,)
raise ValueError(("%s's %s arg received an invalid value %s. " +
'Allowed values are %s.') %
(layer_name, arg_name, input_data, allowed_args))
def count_params(weights):
"""Count the total number of scalars composing the weights.
Arguments:
weights: An iterable containing the weights on which to compute params
Returns:
The total number of scalars composing the weights
"""
unique_weights = object_identity.ObjectIdentitySet(weights)
weight_shapes = [w.shape.as_list() for w in unique_weights]
standardized_weight_shapes = [
[0 if w_i is None else w_i for w_i in w] for w in weight_shapes
]
return int(sum(np.prod(p) for p in standardized_weight_shapes))
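# A small sketch (assuming TensorFlow is importable and eager execution):
#
#   >>> import tensorflow as tf
#   >>> weights = [tf.Variable(tf.zeros([3, 4])), tf.Variable(tf.zeros([4]))]
#   >>> count_params(weights)
#   16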
def print_summary(model, line_length=None, positions=None, print_fn=None):
"""Prints a summary of a model.
Arguments:
model: Keras model instance.
line_length: Total length of printed lines
(e.g. set this to adapt the display to different
terminal window sizes).
positions: Relative or absolute positions of log elements in each line.
If not provided, defaults to `[.33, .55, .67, 1.]`.
print_fn: Print function to use.
It will be called on each line of the summary.
You can set it to a custom function
in order to capture the string summary.
It defaults to `print` (prints to stdout).
"""
if print_fn is None:
print_fn = print
if model.__class__.__name__ == 'Sequential':
sequential_like = True
elif not model._is_graph_network:
# We treat subclassed models as a simple sequence of layers, for logging
# purposes.
sequential_like = True
else:
sequential_like = True
nodes_by_depth = model._nodes_by_depth.values()
nodes = []
for v in nodes_by_depth:
if (len(v) > 1) or (len(v) == 1 and
len(nest.flatten(v[0].inbound_layers)) > 1):
# if the model has multiple nodes
# or if the nodes have multiple inbound_layers
# the model is no longer sequential
sequential_like = False
break
nodes += v
if sequential_like:
# search for shared layers
for layer in model.layers:
flag = False
for node in layer._inbound_nodes:
if node in nodes:
if flag:
sequential_like = False
break
else:
flag = True
if not sequential_like:
break
if sequential_like:
line_length = line_length or 65
positions = positions or [.45, .85, 1.]
if positions[-1] <= 1:
positions = [int(line_length * p) for p in positions]
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #']
else:
line_length = line_length or 98
positions = positions or [.33, .55, .67, 1.]
if positions[-1] <= 1:
positions = [int(line_length * p) for p in positions]
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
relevant_nodes = []
for v in model._nodes_by_depth.values():
relevant_nodes += v
def print_row(fields, positions):
line = ''
for i in range(len(fields)):
if i > 0:
line = line[:-1] + ' '
line += str(fields[i])
line = line[:positions[i]]
line += ' ' * (positions[i] - len(line))
print_fn(line)
print_fn('Model: "{}"'.format(model.name))
print_fn('_' * line_length)
print_row(to_display, positions)
print_fn('=' * line_length)
def print_layer_summary(layer):
"""Prints a summary for a single layer.
Arguments:
layer: target layer.
"""
try:
output_shape = layer.output_shape
except AttributeError:
output_shape = 'multiple'
except RuntimeError: # output_shape unknown in Eager mode.
output_shape = '?'
name = layer.name
cls_name = layer.__class__.__name__
fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
print_row(fields, positions)
def print_layer_summary_with_connections(layer):
"""Prints a summary for a single layer (including topological connections).
Arguments:
layer: target layer.
"""
try:
output_shape = layer.output_shape
except AttributeError:
output_shape = 'multiple'
connections = []
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
tensor_index))
name = layer.name
cls_name = layer.__class__.__name__
if not connections:
first_connection = ''
else:
first_connection = connections[0]
fields = [
name + ' (' + cls_name + ')', output_shape,
layer.count_params(), first_connection
]
print_row(fields, positions)
if len(connections) > 1:
for i in range(1, len(connections)):
fields = ['', '', '', connections[i]]
print_row(fields, positions)
layers = model.layers
for i in range(len(layers)):
if sequential_like:
print_layer_summary(layers[i])
else:
print_layer_summary_with_connections(layers[i])
if i == len(layers) - 1:
print_fn('=' * line_length)
else:
print_fn('_' * line_length)
if hasattr(model, '_collected_trainable_weights'):
trainable_count = count_params(model._collected_trainable_weights)
else:
trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)
print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
print_fn('Trainable params: {:,}'.format(trainable_count))
print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
print_fn('_' * line_length)
def gather_trainable_weights(trainable, sub_layers, extra_variables):
"""Lists the trainable weights for an object with sub-layers.
Args:
trainable: Whether the object collecting the variables is trainable.
sub_layers: A flat list of Layer objects owned by this object, to collect
variables from.
extra_variables: Any extra variables to include. Their `.trainable` property
is used to categorize them.
Returns:
A list of collected trainable weights/variables.
"""
if not trainable:
return []
weights = []
for layer in sub_layers:
weights += layer.trainable_weights
trainable_extra_variables = [
v for v in extra_variables if v.trainable]
return weights + trainable_extra_variables
def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
"""Lists the non-trainable weights for an object with sub-layers.
Args:
trainable: Whether the object collecting the variables is trainable.
sub_layers: A flat list of Layer objects owned by this object, to collect
variables from.
extra_variables: Any extra variables to include. Their `.trainable` property
is used to categorize them.
Returns:
A list of collected non-trainable weights/variables.
"""
trainable_extra_variables = []
non_trainable_extra_variables = []
for v in extra_variables:
if v.trainable:
trainable_extra_variables.append(v)
else:
non_trainable_extra_variables.append(v)
weights = []
for layer in sub_layers:
weights += layer.non_trainable_weights
if not trainable:
trainable_weights = []
for layer in sub_layers:
trainable_weights += layer.trainable_weights
return (trainable_weights + trainable_extra_variables
+ weights + non_trainable_extra_variables)
return weights + non_trainable_extra_variables
@deprecation.deprecated('2020-06-23',
'The Theano kernel format is legacy; '
'this utility will be removed.')
def convert_all_kernels_in_model(model):
"""Converts all convolution kernels in a model from Theano to TensorFlow.
Also works from TensorFlow to Theano.
This is used for converting legacy Theano-saved model files.
Arguments:
model: target model for the conversion.
"""
  # Note: SeparableConvolution is not included
  # since it is only supported by TF.
conv_classes = {
'Conv1D',
'Conv2D',
'Conv3D',
'Conv2DTranspose',
}
to_assign = []
for layer in model.layers:
if layer.__class__.__name__ in conv_classes:
original_kernel = K.get_value(layer.kernel)
converted_kernel = convert_kernel(original_kernel)
to_assign.append((layer.kernel, converted_kernel))
K.batch_set_value(to_assign)
def convert_dense_weights_data_format(dense,
previous_feature_map_shape,
target_data_format='channels_first'):
"""Utility useful when changing a convnet's `data_format`.
When porting the weights of a convnet from one data format to the other,
if the convnet includes a `Flatten` layer
(applied to the last convolutional feature map)
followed by a `Dense` layer, the weights of that `Dense` layer
should be updated to reflect the new dimension ordering.
Arguments:
dense: The target `Dense` layer.
previous_feature_map_shape: A shape tuple of 3 integers,
e.g. `(512, 7, 7)`. The shape of the convolutional
feature map right before the `Flatten` layer that
came before the target `Dense` layer.
target_data_format: One of "channels_last", "channels_first".
Set it "channels_last"
if converting a "channels_first" model to "channels_last",
or reciprocally.
"""
assert target_data_format in {'channels_last', 'channels_first'}
kernel, bias = dense.get_weights()
for i in range(kernel.shape[1]):
if target_data_format == 'channels_first':
c, h, w = previous_feature_map_shape
original_fm_shape = (h, w, c)
ki = kernel[:, i].reshape(original_fm_shape)
ki = np.transpose(ki, (2, 0, 1)) # last -> first
else:
h, w, c = previous_feature_map_shape
original_fm_shape = (c, h, w)
ki = kernel[:, i].reshape(original_fm_shape)
ki = np.transpose(ki, (1, 2, 0)) # first -> last
kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
dense.set_weights([kernel, bias])
def is_builtin_layer(layer):
if not getattr(layer, '_keras_api_names', None):
return False
# Subclasses of `Layer` that are not exported inherit the export name
# of the base layer class.
return (layer._keras_api_names != ('keras.layers.Layer',) and
layer._keras_api_names_v1 != ('keras.layers.Layer',))


@@ -1,524 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow-related utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
import six
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.frozen_keras import backend as K
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
"""Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
If `pred` is a bool or has a constant value, we return either `true_fn()`
or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
Arguments:
pred: A scalar determining whether to return the result of `true_fn` or
`false_fn`.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
name: Optional name prefix when using `tf.cond`.
Returns:
Tensors returned by the call to either `true_fn` or `false_fn`.
Raises:
TypeError: If `true_fn` or `false_fn` is not callable.
"""
if isinstance(pred, variables.Variable):
return control_flow_ops.cond(
pred, true_fn=true_fn, false_fn=false_fn, name=name)
return smart_module.smart_cond(
pred, true_fn=true_fn, false_fn=false_fn, name=name)
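# A small sketch (assuming eager execution, where a constant predicate
# short-circuits directly to one branch):
#
#   >>> from tensorflow.python.framework import constant_op
#   >>> int(smart_cond(True,
#   ...                lambda: constant_op.constant(1),
#   ...                lambda: constant_op.constant(2)))
#   1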
def constant_value(pred):
"""Return the bool value for `pred`, or None if `pred` had a dynamic value.
Arguments:
pred: A scalar, either a Python bool or a TensorFlow boolean variable
or tensor, or the Python integer 1 or 0.
Returns:
True or False if `pred` has a constant boolean value, None otherwise.
Raises:
TypeError: If `pred` is not a Variable, Tensor or bool, or Python
integer 1 or 0.
"""
# Allow integer booleans.
if isinstance(pred, int):
if pred == 1:
pred = True
elif pred == 0:
pred = False
if isinstance(pred, variables.Variable):
return None
return smart_module.smart_constant_value(pred)
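# For example, plain Python values fold to booleans, while a `Variable` or a
# placeholder would yield None:
#
#   >>> constant_value(1), constant_value(0)
#   (True, False)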
def is_tensor_or_tensor_list(v):
v = nest.flatten(v)
if v and isinstance(v[0], ops.Tensor):
return True
else:
return False
def get_reachable_from_inputs(inputs, targets=None):
"""Returns the set of tensors/ops reachable from `inputs`.
Stops if all targets have been found (target is optional).
Only valid in Symbolic mode, not Eager mode.
Args:
inputs: List of tensors.
targets: List of tensors.
Returns:
A set of tensors reachable from the inputs (includes the inputs themselves).
"""
inputs = nest.flatten(inputs, expand_composites=True)
reachable = object_identity.ObjectIdentitySet(inputs)
if targets:
remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
queue = inputs[:]
while queue:
x = queue.pop()
if isinstance(x, tuple(_user_convertible_tensor_types)):
# Can't find consumers of user-specific types.
continue
if isinstance(x, ops.Operation):
outputs = x.outputs[:] or []
outputs += x._control_outputs # pylint: disable=protected-access
elif isinstance(x, variables.Variable):
try:
outputs = [x.op]
except AttributeError:
# Variables can be created in an Eager context.
outputs = []
elif tensor_util.is_tensor(x):
outputs = x.consumers()
else:
raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
for y in outputs:
if y not in reachable:
reachable.add(y)
if targets:
remaining_targets.discard(y)
queue.insert(0, y)
if targets and not remaining_targets:
return reachable
return reachable
# This function needs access to private functions of `nest`.
# pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
"""Maps the atomic elements of a nested structure.
Arguments:
is_atomic_fn: A function that determines if an element of `nested` is
atomic.
map_fn: The function to apply to atomic elements of `nested`.
nested: A nested structure.
Returns:
The nested structure, with atomic elements mapped according to `map_fn`.
Raises:
ValueError: If an element that is neither atomic nor a sequence is
encountered.
"""
if is_atomic_fn(nested):
return map_fn(nested)
# Recursively convert.
if not nest.is_sequence(nested):
raise ValueError(
'Received non-atomic and non-sequence element: {}'.format(nested))
if nest._is_mapping(nested):
values = [nested[k] for k in nest._sorted(nested)]
else:
values = nested
mapped_values = [
map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
]
return nest._sequence_like(nested, mapped_values)
# pylint: enable=protected-access
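# A small sketch of `map_structure_with_atomic`, treating ints as atomic:
#
#   >>> map_structure_with_atomic(
#   ...     lambda x: isinstance(x, int), lambda x: x + 1, {'a': [1, 2], 'b': 3})
#   {'a': [2, 3], 'b': 4}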
def convert_shapes(input_shape, to_tuples=True):
"""Converts nested shape representations to desired format.
Performs:
TensorShapes -> tuples if `to_tuples=True`.
tuples of int or None -> TensorShapes if `to_tuples=False`.
Valid objects to be converted are:
- TensorShapes
- tuples with elements of type int or None.
- ints
- None
Arguments:
input_shape: A nested structure of objects to be converted to TensorShapes.
to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
all tuples representing shapes to TensorShapes.
Returns:
Nested structure of shapes in desired format.
Raises:
    ValueError: when the input tensor shape can't be converted to tuples,
      e.g. an unknown tensor shape.
"""
def _is_shape_component(value):
return value is None or isinstance(value, (int, tensor_shape.Dimension))
def _is_atomic_shape(input_shape):
# Ex: TensorShape or (None, 10, 32) or 5 or `None`
if _is_shape_component(input_shape):
return True
if isinstance(input_shape, tensor_shape.TensorShape):
return True
if (isinstance(input_shape, (tuple, list)) and
all(_is_shape_component(ele) for ele in input_shape)):
return True
return False
def _convert_shape(input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if to_tuples:
input_shape = tuple(input_shape.as_list())
return input_shape
return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
input_shape)
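# For instance, converting a nested structure of TensorShapes to tuples
# (`tensor_shape` is imported above):
#
#   >>> convert_shapes({'out': tensor_shape.TensorShape([None, 32])},
#   ...                to_tuples=True)
#   {'out': (None, 32)}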
class ListWrapper(object):
"""A wrapper for lists to be treated as elements for `nest`."""
def __init__(self, list_to_wrap):
self._list = list_to_wrap
def as_list(self):
return self._list
def convert_inner_node_data(nested, wrap=False):
"""Either wraps or unwraps innermost node data lists in `ListWrapper` objects.
Arguments:
nested: A nested data structure.
wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
unwraps `ListWrapper` objects into lists.
Returns:
Structure of same type as nested, with lists wrapped/unwrapped.
"""
def _is_serialized_node_data(nested):
# Node data can be of form `[layer_name, node_id, tensor_id]` or
# `[layer_name, node_id, tensor_id, kwargs]`.
if (isinstance(nested, list) and (len(nested) in [3, 4]) and
isinstance(nested[0], six.string_types)):
return True
return False
def _is_atomic_nested(nested):
"""Returns `True` if `nested` is a list representing node data."""
if isinstance(nested, ListWrapper):
return True
if _is_serialized_node_data(nested):
return True
return not nest.is_sequence(nested)
def _convert_object_or_list(nested):
"""Convert b/t `ListWrapper` object and list representations."""
if wrap:
if isinstance(nested, ListWrapper):
return nested
if _is_serialized_node_data(nested):
return ListWrapper(nested)
return nested
else:
if isinstance(nested, ListWrapper):
return nested.as_list()
return nested
return map_structure_with_atomic(_is_atomic_nested, _convert_object_or_list,
nested)
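# Example (illustrative sketch, not part of the original file): wrapping
# serialized node data so `nest` treats each [layer_name, node_id, tensor_id]
# list as one element instead of recursing into it.
def _demo_wrap_node_data():
  node_data = [['dense', 0, 0], ['dropout', 0, 0]]
  wrapped = convert_inner_node_data(node_data, wrap=True)
  # Each inner list is now a ListWrapper; the outer list stays a plain list.
  return convert_inner_node_data(wrapped, wrap=False)  # back to plain lists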
def shape_type_conversion(fn):
"""Decorator that handles tuple/TensorShape conversion.
Used in `compute_output_shape` and `build`.
Arguments:
fn: function to wrap.
Returns:
Wrapped function.
"""
def wrapper(instance, input_shape):
# Pass shapes as tuples to `fn`
# This preserves compatibility with external Keras.
if input_shape is not None:
input_shape = convert_shapes(input_shape, to_tuples=True)
output_shape = fn(instance, input_shape)
# Return shapes from `fn` as TensorShapes.
if output_shape is not None:
output_shape = convert_shapes(output_shape, to_tuples=False)
return output_shape
return wrapper
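# Example (illustrative sketch, not part of the original file): `_DemoLayer`
# is a hypothetical class whose decorated method sees plain tuples even when
# the caller passes TensorShapes, and whose tuple result is handed back to
# the caller as a TensorShape.
class _DemoLayer(object):

  @shape_type_conversion
  def compute_output_shape(self, input_shape):
    # `input_shape` arrives here as a tuple such as (None, 32).
    return input_shape[:-1] + (64,)
# _DemoLayer().compute_output_shape(tensor_shape.TensorShape([None, 32]))
# returns TensorShape([None, 64]).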
def are_all_symbolic_tensors(tensors):
return all(is_symbolic_tensor(tensor) for tensor in tensors)
_user_convertible_tensor_types = set()
def is_symbolic_tensor(tensor):
"""Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.
  A Variable can be seen as either kind: it is treated as symbolic
  when we are in a graph scope, and as eager when we are in an eager scope.
Arguments:
tensor: A tensor instance to test.
Returns:
True for symbolic tensors, False for eager tensors.
"""
if isinstance(tensor, tuple(_user_convertible_tensor_types)):
tensor = ops.convert_to_tensor_or_composite(tensor)
if isinstance(tensor, variables.Variable):
# Variables that are output of a Keras Layer in Functional API mode
# should be considered symbolic.
# TODO(omalleyt): We need a better way to check this in order to
# enable `run_eagerly=True` for Models containing Layers that
# return Variables as outputs.
return (getattr(tensor, '_keras_history', False) or
not context.executing_eagerly())
if isinstance(tensor, composite_tensor.CompositeTensor):
component_tensors = nest.flatten(tensor, expand_composites=True)
return any(hasattr(t, 'graph') for t in component_tensors)
if isinstance(tensor, ops.Tensor):
return hasattr(tensor, 'graph')
return False
def register_symbolic_tensor_type(cls):
"""Allows users to specify types regarded as symbolic `Tensor`s.
Used in conjunction with `tf.register_tensor_conversion_function`, calling
`tf.keras.utils.register_symbolic_tensor_type(cls)` allows non-`Tensor`
objects to be plumbed through Keras layers.
Example:
```python
# One-time setup.
class Foo(object):
def __init__(self, input_):
self._input = input_
def value(self):
return tf.constant(42.)
tf.register_tensor_conversion_function(
Foo, lambda x, *args, **kwargs: x.value())
tf.keras.utils.register_symbolic_tensor_type(Foo)
# User-land.
layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
```
Arguments:
cls: A `class` type which shall be regarded as a symbolic `Tensor`.
"""
global _user_convertible_tensor_types
_user_convertible_tensor_types.add(cls)
def type_spec_from_value(value):
"""Grab type_spec without converting array-likes to tensors."""
if isinstance(value, composite_tensor.CompositeTensor):
return value._type_spec # pylint: disable=protected-access
# Get a TensorSpec for array-like data without
# converting the data to a Tensor
if hasattr(value, 'shape') and hasattr(value, 'dtype'):
return tensor_spec.TensorSpec(value.shape, value.dtype)
else:
return type_spec.type_spec_from_value(value)
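# Example (illustrative sketch, not part of the original file): a NumPy array
# yields a TensorSpec without the data being copied into a Tensor.
def _demo_spec_from_numpy():
  data = np.zeros((8, 4), dtype=np.float32)
  return type_spec_from_value(data)  # TensorSpec(shape=(8, 4), dtype=float32)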
def is_tensor_or_variable(x):
return tensor_util.is_tensor(x) or isinstance(x, variables.Variable)
def assert_no_legacy_layers(layers):
"""Prevent tf.layers.Layers from being used with Keras.
  Certain legacy layers inherit from their Keras analogs; however, they are
  not supported with Keras and can lead to subtle, hard-to-diagnose bugs.
  Args:
    layers: A list of layers to check.
  Raises:
    TypeError: If any elements of `layers` are tf.layers.Layers.
"""
# isinstance check for tf.layers.Layer introduces a circular dependency.
legacy_layers = [l for l in layers if getattr(l, '_is_legacy_layer', None)]
if legacy_layers:
layer_str = '\n'.join(' ' + str(l) for l in legacy_layers)
raise TypeError(
'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
'framework (for instance using the Network, Model, or Sequential '
'classes), please use the tf.keras.layers implementation instead. '
'(Or, if writing custom layers, subclass from tf.keras.layers rather '
'than tf.layers)'.format(layer_str))
@tf_contextlib.contextmanager
def maybe_init_scope(layer):
"""Open an `init_scope` if in V2 mode and using the keras graph.
Arguments:
layer: The Layer/Model that is currently active.
Yields:
None
"""
# Don't open an init_scope in V1 mode or when using legacy tf.layers.
if (ops.executing_eagerly_outside_functions() and
getattr(layer, '_keras_style', True)):
with ops.init_scope():
yield
else:
yield
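# Example (illustrative sketch, not part of the original file): a typical
# variable-creating call site. In V2 mode with keras-style layers, variables
# created under the scope are lifted out of any active `tf.function` trace.
def _demo_create_weight(layer):
  with maybe_init_scope(layer):
    return variables.Variable(0.)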
@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
"""Returns graph context manager if any of the inputs is a symbolic tensor."""
if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())):
with K.get_graph().as_default():
yield
else:
yield
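# Example (illustrative sketch, not part of the original file): `_demo_scale`
# is a hypothetical helper; if either argument is symbolic, the ops it builds
# are placed in the Keras graph rather than executed eagerly.
def _demo_scale(x, factor):
  with graph_context_for_symbolic_tensors(x, factor=factor):
    return math_ops.multiply(x, factor)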
def dataset_is_infinite(dataset):
"""True if the passed dataset is infinite."""
if ops.executing_eagerly_outside_functions():
return math_ops.equal(
cardinality.cardinality(dataset), cardinality.INFINITE)
else:
dataset_size = K.get_session().run(cardinality.cardinality(dataset))
return dataset_size == cardinality.INFINITE
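# Example (illustrative sketch, not part of the original file; the
# `dataset_ops` import below is not among this file's imports): an unbounded
# `repeat()` makes a dataset infinite.
def _demo_infinite_dataset():
  from tensorflow.python.data.ops import dataset_ops
  finite = dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3])
  # In eager mode this returns a boolean Tensor: False for `finite` itself,
  # True for the unbounded repeat below.
  return dataset_is_infinite(finite.repeat())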
def get_tensor_spec(t, dynamic_batch=False, name=None):
"""Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
if isinstance(t, type_spec.TypeSpec):
spec = t
elif isinstance(t, composite_tensor.CompositeTensor):
# TODO(b/148821952): Should these specs have a name attr?
spec = t._type_spec # pylint: disable=protected-access
elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
else:
return None # Allow non-Tensors to pass through.
if not dynamic_batch:
return spec
dynamic_batch_spec = copy.deepcopy(spec)
# RaggedTensorSpec only has a private _shape.
shape = dynamic_batch_spec._shape.as_list() # pylint: disable=protected-access
if shape:
shape[0] = None
dynamic_batch_spec._shape = tensor_shape.TensorShape(shape) # pylint: disable=protected-access
return dynamic_batch_spec
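# Example (illustrative sketch, not part of the original file): relaxing the
# batch dimension of a spec with `dynamic_batch=True`.
def _demo_dynamic_batch_spec():
  spec = tensor_spec.TensorSpec(shape=(32, 224, 224, 3), dtype='float32')
  return get_tensor_spec(spec, dynamic_batch=True)
  # Returns TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32).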
def to_numpy_or_python_type(tensors):
"""Converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.
For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
it converts it to a Python type, such as a float or int, by calling
`result.item()`.
Numpy scalars are converted, as Python types are often more convenient to deal
with. This is especially useful for bfloat16 Numpy scalars, which don't
support as many operations as other Numpy values.
Args:
tensors: A structure of tensors.
Returns:
`tensors`, but scalar tensors are converted to Python types and non-scalar
tensors are converted to Numpy arrays.
"""
def _to_single_numpy_or_python_type(t):
if isinstance(t, ops.Tensor):
x = t.numpy()
return x.item() if np.ndim(x) == 0 else x
return t # Don't turn ragged or sparse tensors to NumPy.
return nest.map_structure(_to_single_numpy_or_python_type, tensors)
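# Example (illustrative sketch, not part of the original file): in eager mode,
# scalar tensors come back as Python scalars and non-scalar tensors as NumPy
# arrays.
def _demo_to_numpy():
  results = {'loss': ops.convert_to_tensor_v2(0.25),
             'logits': ops.convert_to_tensor_v2([1.0, 2.0])}
  return to_numpy_or_python_type(results)
  # Returns {'loss': 0.25, 'logits': np.array([1., 2.], dtype=np.float32)}.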


@@ -1,162 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras TF utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.frozen_keras.utils import tf_utils
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class TestIsSymbolicTensor(test.TestCase):
def test_default_behavior(self):
if context.executing_eagerly():
self.assertFalse(tf_utils.is_symbolic_tensor(
variables.Variable(name='blah', initial_value=0.)))
self.assertFalse(
tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
self.assertFalse(tf_utils.is_symbolic_tensor(
sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
else:
self.assertTrue(tf_utils.is_symbolic_tensor(
variables.Variable(name='blah', initial_value=0.)))
self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
self.assertTrue(tf_utils.is_symbolic_tensor(
sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
def test_works_with_registered(self):
class CustomClass(object):
def value(self):
return ops.convert_to_tensor_v2(42.)
ops.register_tensor_conversion_function(
CustomClass, lambda value, **_: value.value())
tf_utils.register_symbolic_tensor_type(CustomClass)
if context.executing_eagerly():
self.assertFalse(tf_utils.is_symbolic_tensor(
variables.Variable(name='blah', initial_value=0.)))
self.assertFalse(
tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
self.assertFalse(tf_utils.is_symbolic_tensor(
sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
self.assertFalse(tf_utils.is_symbolic_tensor(CustomClass()))
else:
self.assertTrue(tf_utils.is_symbolic_tensor(
variables.Variable(name='blah', initial_value=0.)))
self.assertTrue(tf_utils.is_symbolic_tensor(ops.convert_to_tensor_v2(0.)))
self.assertTrue(tf_utils.is_symbolic_tensor(
sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))
self.assertTrue(tf_utils.is_symbolic_tensor(CustomClass()))
def test_enables_nontensor_plumbing(self):
self.skipTest('Sequential model will check layer instance type and fail.')
if context.executing_eagerly():
self.skipTest('`compile` functionality changed.')
# Setup.
class Foo(object):
def __init__(self, input_):
self._input = input_
self.value = ops.convert_to_tensor_v2([[42.]])
@property
def dtype(self):
return self.value.dtype
ops.register_tensor_conversion_function(
Foo, lambda x, *args, **kwargs: x.value)
tf_utils.register_symbolic_tensor_type(Foo)
class PlumbingLayer(keras.layers.Lambda):
def __init__(self, fn, **kwargs):
def _fn(*fargs, **fkwargs):
d = fn(*fargs, **fkwargs)
x = ops.convert_to_tensor_v2(d)
d.shape = x.shape
d.get_shape = x.get_shape
return d, x
super(PlumbingLayer, self).__init__(_fn, **kwargs)
self._enter_dunder_call = False
def __call__(self, inputs, *args, **kwargs):
self._enter_dunder_call = True
d, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
self._enter_dunder_call = False
return d
def call(self, inputs, *args, **kwargs):
d, v = super(PlumbingLayer, self).call(inputs, *args, **kwargs)
if self._enter_dunder_call:
return d, v
return d
# User-land.
model = keras.Sequential([
keras.layers.InputLayer((1,)),
PlumbingLayer(Foo), # Makes a `Foo` object.
])
# Let's ensure Keras graph history is preserved by composing the models.
model = keras.Model(model.inputs, model(model.outputs))
# Now we instantiate the model and verify we have a `Foo` object, not a
# `Tensor`.
y = model(ops.convert_to_tensor_v2([[7.]]))
self.assertIsInstance(y, Foo)
# Confirm that (custom) loss sees `Foo` instance, not Tensor.
obtained_prediction_box = [None]
def custom_loss(y_obs, y_pred):
del y_obs
obtained_prediction_box[0] = y_pred
return y_pred
# Apparently `compile` calls the loss function enough to trigger the
# side-effect.
model.compile('SGD', loss=custom_loss)
self.assertIsInstance(obtained_prediction_box[0], Foo)
class ConvertInnerNodeDataTest(test.TestCase):
def test_convert_inner_node_data(self):
data = tf_utils.convert_inner_node_data((tf_utils.ListWrapper(['l', 2, 3]),
tf_utils.ListWrapper(['l', 5, 6])))
self.assertEqual(data, (['l', 2, 3], ['l', 5, 6]))
data = tf_utils.convert_inner_node_data(((['l', 2, 3], ['l', 5, 6])),
wrap=True)
self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))
if __name__ == '__main__':
test.main()


@@ -44,7 +44,6 @@ py_library(
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/eager:monitoring",
"//tensorflow/python/frozen_keras/engine:legacy_base_layer",
"//tensorflow/python/keras:activations",
"//tensorflow/python/keras:backend",
"//tensorflow/python/keras:callbacks",


@@ -21,7 +21,6 @@ from __future__ import print_function
import copy
-from tensorflow.python.frozen_keras.engine import legacy_base_layer
from tensorflow.python.keras import layers as layer_module
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer
@@ -169,8 +168,7 @@ class Sequential(training.Model):
if isinstance(origin_layer, input_layer.InputLayer):
layer = origin_layer
-    if not isinstance(layer,
-                      (base_layer.Layer, legacy_base_layer.LegacyBaseLayer)):
+    if not isinstance(layer, base_layer.Layer):
raise TypeError('The added layer must be '
'an instance of class Layer. '
'Found: ' + str(layer))


@@ -122,7 +122,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/eager:eager_pip",
"//tensorflow/python/frozen_keras/engine:legacy_base_layer",
"//tensorflow/python/keras:combinations",
"//tensorflow/python/keras/layers/preprocessing:preprocessing_test_utils",
"//tensorflow/python/keras/distribute:distribute_strategy_test_lib",