Copy constraint/initializer/regularizers to frozen_keras.
This is used by legacy_base_layer and other layer subclasses.

PiperOrigin-RevId: 299410923
Change-Id: I8011fca38347bfb30956d8560ce6c144fd6069b4

parent e3c18e7beb
commit 6f795b7539

tensorflow/python/frozen_keras/BUILD (new file, +55)
@@ -0,0 +1,55 @@
package(
    default_visibility = ["//tensorflow:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)

#TODO(scottzhu): Cleanup all the deps to python/keras

py_library(
    name = "frozen_keras",
    deps = [
        ":constraint",
        ":initializers",
        ":regularizers",
        "//tensorflow/python/frozen_keras/engine:legacy_base_layer",
    ],
)

py_library(
    name = "constraint",
    srcs = ["constraints.py"],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras/utils:generic_utils",
        "@six_archive//:six",
    ],
)

py_library(
    name = "initializers",
    srcs = ["initializers.py"],
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:tf2",
        "//tensorflow/python/keras/utils:generic_utils",
        "@six_archive//:six",
    ],
)

py_library(
    name = "regularizers",
    srcs = ["regularizers.py"],
    deps = [
        "//tensorflow/python:math_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras/utils:generic_utils",
        "@six_archive//:six",
    ],
)
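For orientation, a minimal sketch (not part of this commit) of the import surface these targets provide. The `get()` behaviors follow the module sources shown below in this diff; the dict-style initializer lookup assumes TF2 is enabled:

```
from tensorflow.python.frozen_keras import constraints
from tensorflow.python.frozen_keras import initializers
from tensorflow.python.frozen_keras import regularizers

# String/dict identifiers resolve through each module's get().
kernel_constraint = constraints.get('max_norm')  # -> MaxNorm() via the `max_norm` alias
kernel_regularizer = regularizers.get('l1_l2')   # special-cased -> L1L2(l1=0.01, l2=0.01)
kernel_initializer = initializers.get({'class_name': 'GlorotUniform', 'config': {}})
```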
tensorflow/python/frozen_keras/constraints.py (new file, +282)
@@ -0,0 +1,282 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-classes-have-attributes
"""Constraints: functions that impose constraints on weight values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


class Constraint(object):

  def __call__(self, w):
    return w

  def get_config(self):
    return {}


class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Arguments:
    max_value: the maximum norm for the incoming weights.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  def __call__(self, w):
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    desired = K.clip(norms, 0, self.max_value)
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {'max_value': self.max_value, 'axis': self.axis}


class NonNeg(Constraint):
  """Constrains the weights to be non-negative."""

  def __call__(self, w):
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())


class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Arguments:
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  def __call__(self, w):
    return w / (
        K.epsilon() + K.sqrt(
            math_ops.reduce_sum(
                math_ops.square(w), axis=self.axis, keepdims=True)))

  def get_config(self):
    return {'axis': self.axis}


class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Arguments:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  def __call__(self, w):
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    desired = (
        self.rate * K.clip(norms, self.min_value, self.max_value) +
        (1 - self.rate) * norms)
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis
    }


class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  For example, the desired output for the following 4-by-4 kernel:

  ```
  kernel = [[v_00, v_01, v_02, v_03],
            [v_10, v_11, v_12, v_13],
            [v_20, v_21, v_22, v_23],
            [v_30, v_31, v_32, v_33]]
  ```

  is this:

  ```
  kernel = [[v_11, v_11, v_11, v_11],
            [v_11, v_33, v_33, v_11],
            [v_11, v_33, v_33, v_11],
            [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)`.
  """

  def __call__(self, w):
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    w = K.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
    # is supported.
    w = K.map_fn(
        self._kernel_constraint,
        K.stack(array_ops.unstack(w, axis=-1), axis=0))
    return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
                     (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains a kernel with shape (height, width, channels)."""
    padding = K.constant([[1, 1], [1, 1]], dtype='int32')

    kernel_shape = K.shape(kernel)[0]
    start = K.cast(kernel_shape / 2, 'int32')

    kernel_new = K.switch(
        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: kernel[start - 1:start, start - 1:start],
        lambda: kernel[start - 1:start, start - 1:start] + K.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    index = K.switch(
        K.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: K.constant(0, dtype='int32'),
        lambda: K.constant(1, dtype='int32'))
    while_condition = lambda index, *args: K.less(index, start)

    def body_fn(i, array):
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    _, kernel_new = control_flow_ops.while_loop(
        while_condition,
        body_fn,
        [index, kernel_new],
        shape_invariants=[index.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new


# Aliases.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm


def serialize(constraint):
  return serialize_keras_object(constraint)


def deserialize(config, custom_objects=None):
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='constraint')


def get(identifier):
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret constraint identifier: ' +
                     str(identifier))
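The rescaling arithmetic shared by `MaxNorm` and `MinMaxNorm` is easiest to check on a concrete tensor. A small sketch (not part of this commit) against the public `tf.keras.constraints.MaxNorm`, which this frozen copy mirrors:

```
import tensorflow as tf

# Columns whose L2 norm exceeds max_value are rescaled onto the norm ball;
# columns already inside it are left (up to epsilon) unchanged.
w = tf.constant([[3.0, 0.3],
                 [4.0, 0.4]])  # column L2 norms: 5.0 and 0.5
constrained = tf.keras.constraints.MaxNorm(max_value=2.0, axis=0)(w)
print(constrained.numpy())  # first column scaled by 2/5 -> [[1.2, 0.3], [1.6, 0.4]]
```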
tensorflow/python/frozen_keras/engine/BUILD (modified)

@@ -5,6 +5,8 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+#TODO(scottzhu): Cleanup all the deps to python/keras
+
 py_library(
     name = "legacy_base_layer",
     srcs = ["legacy_base_layer.py"],
@@ -29,11 +31,11 @@ py_library(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:execute",
         "//tensorflow/python/eager:function",
+        "//tensorflow/python/frozen_keras:constraint",
+        "//tensorflow/python/frozen_keras:initializers",
+        "//tensorflow/python/frozen_keras:regularizers",
         "//tensorflow/python/keras:backend",
-        "//tensorflow/python/keras:constraints",
-        "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras:metrics",
-        "//tensorflow/python/keras:regularizers",
         "//tensorflow/python/keras/engine",
         "//tensorflow/python/keras/engine:base_layer_utils",
         "//tensorflow/python/keras/engine:input_spec",
tensorflow/python/frozen_keras/engine/legacy_base_layer.py (modified)

@@ -51,10 +51,10 @@ from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.frozen_keras import constraints
+from tensorflow.python.frozen_keras import initializers
+from tensorflow.python.frozen_keras import regularizers
 from tensorflow.python.keras import backend
-from tensorflow.python.keras import constraints
-from tensorflow.python.keras import initializers
-from tensorflow.python.keras import regularizers
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_spec
 from tensorflow.python.keras.engine import node as node_module
tensorflow/python/frozen_keras/initializers.py (new file, +198)
@@ -0,0 +1,198 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import init_ops_v2

# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
from tensorflow.python.ops.init_ops import GlorotNormal
from tensorflow.python.ops.init_ops import GlorotUniform
from tensorflow.python.ops.init_ops import he_normal  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import he_uniform  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_normal  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_uniform  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormal
from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniform
from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormal
from tensorflow.python.ops.init_ops import VarianceScaling  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
# pylint: disable=unused-import, disable=line-too-long
from tensorflow.python.ops.init_ops_v2 import Constant as ConstantV2
from tensorflow.python.ops.init_ops_v2 import GlorotNormal as GlorotNormalV2
from tensorflow.python.ops.init_ops_v2 import GlorotUniform as GlorotUniformV2
from tensorflow.python.ops.init_ops_v2 import he_normal as he_normalV2
from tensorflow.python.ops.init_ops_v2 import he_uniform as he_uniformV2
from tensorflow.python.ops.init_ops_v2 import Identity as IdentityV2
from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
from tensorflow.python.ops.init_ops_v2 import lecun_normal as lecun_normalV2
from tensorflow.python.ops.init_ops_v2 import lecun_uniform as lecun_uniformV2
from tensorflow.python.ops.init_ops_v2 import Ones as OnesV2
from tensorflow.python.ops.init_ops_v2 import Orthogonal as OrthogonalV2
from tensorflow.python.ops.init_ops_v2 import RandomNormal as RandomNormalV2
from tensorflow.python.ops.init_ops_v2 import RandomUniform as RandomUniformV2
from tensorflow.python.ops.init_ops_v2 import TruncatedNormal as TruncatedNormalV2
from tensorflow.python.ops.init_ops_v2 import VarianceScaling as VarianceScalingV2
from tensorflow.python.ops.init_ops_v2 import Zeros as ZerosV2
# pylint: enable=unused-import, enable=line-too-long


class TruncatedNormal(TFTruncatedNormal):
  """Initializer that generates a truncated normal distribution.

  These values are similar to values from a `random_normal_initializer`
  except that values more than two standard deviations from the mean
  are discarded and re-drawn. This is the recommended initializer for
  neural network weights and filters.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate. Defaults to 0.
    stddev: a python scalar or a scalar tensor. Standard deviation of the
      random values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    A TruncatedNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super(TruncatedNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


class RandomUniform(TFRandomUniform):
  """Initializer that generates tensors with a uniform distribution.

  Args:
    minval: A python scalar or a scalar tensor. Lower bound of the range of
      random values to generate. Defaults to -0.05.
    maxval: A python scalar or a scalar tensor. Upper bound of the range of
      random values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type.

  Returns:
    A RandomUniform instance.
  """

  def __init__(self, minval=-0.05, maxval=0.05, seed=None,
               dtype=dtypes.float32):
    super(RandomUniform, self).__init__(
        minval=minval, maxval=maxval, seed=seed, dtype=dtype)


class RandomNormal(TFRandomNormal):
  """Initializer that generates tensors with a normal distribution.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate. Defaults to 0.
    stddev: a python scalar or a scalar tensor. Standard deviation of the
      random values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    RandomNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super(RandomNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


# Compatibility aliases

# pylint: disable=invalid-name
zero = zeros = Zeros
one = ones = Ones
constant = Constant
uniform = random_uniform = RandomUniform
normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform


# Utility functions


def serialize(initializer):
  return serialize_keras_object(initializer)


def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config."""
  if tf2.enabled():
    # Class names are the same for V1 and V2 but the V2 classes
    # are aliased in this file so we need to grab them directly
    # from `init_ops_v2`.
    module_objects = {
        obj_name: getattr(init_ops_v2, obj_name)
        for obj_name in dir(init_ops_v2)
    }
  else:
    module_objects = globals()
  return deserialize_keras_object(
      config,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='initializer')


def get(identifier):
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    identifier = str(identifier)
    # We have to special-case functions that return classes.
    # TODO(omalleyt): Turn these into classes or class aliases.
    special_cases = ['he_normal', 'he_uniform', 'lecun_normal', 'lecun_uniform']
    if identifier in special_cases:
      # Treat like a class.
      return deserialize({'class_name': identifier, 'config': {}})
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))


# pylint: enable=invalid-name
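The serialize/deserialize pair above round-trips an initializer through its config dict, which is how model configs are read back. A short sketch (not part of this commit) against the public `tf.keras.initializers` module, of which this file is a frozen copy:

```
import tensorflow as tf

# Round-trip an initializer through its config, as deserialize() does when
# reading a saved model config.
init = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05)
config = tf.keras.initializers.serialize(init)  # {'class_name': ..., 'config': {...}}
restored = tf.keras.initializers.deserialize(config)
sample = restored(shape=(2, 4))                 # a (2, 4) tensor of initial weights
```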
tensorflow/python/frozen_keras/regularizers.py (new file, +306)
@@ -0,0 +1,306 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in regularizers."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops


class Regularizer(object):
  """Regularizer base class.

  Regularizers allow you to apply penalties on layer parameters or layer
  activity during optimization. These penalties are summed into the loss
  function that the network optimizes.

  Regularization penalties are applied on a per-layer basis. The exact API will
  depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D` and
  `Conv3D`) have a unified API.

  These layers expose 3 keyword arguments:

  - `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel
  - `bias_regularizer`: Regularizer to apply a penalty on the layer's bias
  - `activity_regularizer`: Regularizer to apply a penalty on the layer's output

  All layers (including custom layers) expose `activity_regularizer` as a
  settable property, whether or not it is in the constructor arguments.

  The value returned by the `activity_regularizer` is divided by the input
  batch size so that the relative weighting between the weight regularizers and
  the activity regularizers does not change with the batch size.

  You can access a layer's regularization penalties by calling `layer.losses`
  after calling the layer on inputs.

  ## Example

  >>> layer = tf.keras.layers.Dense(
  ...     5, input_dim=5,
  ...     kernel_initializer='ones',
  ...     kernel_regularizer=tf.keras.regularizers.l1(0.01),
  ...     activity_regularizer=tf.keras.regularizers.l2(0.01))
  >>> tensor = tf.ones(shape=(5, 5)) * 2.0
  >>> out = layer(tensor)

  >>> # The kernel regularization term is 0.25
  >>> # The activity regularization term (after dividing by the batch size) is 5
  >>> tf.math.reduce_sum(layer.losses)
  <tf.Tensor: shape=(), dtype=float32, numpy=5.25>

  ## Available penalties

  ```python
  tf.keras.regularizers.l1(0.3)  # L1 Regularization Penalty
  tf.keras.regularizers.l2(0.1)  # L2 Regularization Penalty
  tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)  # L1 + L2 penalties
  ```

  ## Directly calling a regularizer

  Compute a regularization loss on a tensor by directly calling a regularizer
  as if it is a one-argument function.

  E.g.
  >>> regularizer = tf.keras.regularizers.l2(2.)
  >>> tensor = tf.ones(shape=(5, 5))
  >>> regularizer(tensor)
  <tf.Tensor: shape=(), dtype=float32, numpy=50.0>

  ## Developing new regularizers

  Any function that takes in a weight matrix and returns a scalar
  tensor can be used as a regularizer, e.g.:

  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l1')
  ... def l1_reg(weight_matrix):
  ...   return 0.01 * tf.math.reduce_sum(tf.math.abs(weight_matrix))
  ...
  >>> layer = tf.keras.layers.Dense(5, input_dim=5,
  ...     kernel_initializer='ones', kernel_regularizer=l1_reg)
  >>> tensor = tf.ones(shape=(5, 5))
  >>> out = layer(tensor)
  >>> layer.losses
  [<tf.Tensor: shape=(), dtype=float32, numpy=0.25>]

  Alternatively, you can write your custom regularizers in an
  object-oriented way by extending this regularizer base class, e.g.:

  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l2')
  ... class L2Regularizer(tf.keras.regularizers.Regularizer):
  ...   def __init__(self, l2=0.):  # pylint: disable=redefined-outer-name
  ...     self.l2 = l2
  ...
  ...   def __call__(self, x):
  ...     return self.l2 * tf.math.reduce_sum(tf.math.square(x))
  ...
  ...   def get_config(self):
  ...     return {'l2': float(self.l2)}
  ...
  >>> layer = tf.keras.layers.Dense(
  ...     5, input_dim=5, kernel_initializer='ones',
  ...     kernel_regularizer=L2Regularizer(l2=0.5))

  >>> tensor = tf.ones(shape=(5, 5))
  >>> out = layer(tensor)
  >>> layer.losses
  [<tf.Tensor: shape=(), dtype=float32, numpy=12.5>]

  ### A note on serialization and deserialization:

  Registering the regularizers as serializable is optional if you are just
  training and executing models, exporting to and from SavedModels, or saving
  and loading weight checkpoints.

  Registration is required for Keras `model_to_estimator`, saving and
  loading models to HDF5 formats, Keras model cloning, some visualization
  utilities, and exporting models to and from JSON. If using this functionality,
  you must make sure any python process running your model has also defined
  and registered your custom regularizer.

  `tf.keras.utils.register_keras_serializable` is only available in TF 2.1 and
  beyond. In earlier versions of TensorFlow you must pass your custom
  regularizer to the `custom_objects` argument of methods that expect custom
  regularizers to be registered as serializable.
  """

  def __call__(self, x):
    """Compute a regularization penalty from an input tensor."""
    return 0.

  @classmethod
  def from_config(cls, config):
    """Creates a regularizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same regularizer from the config
    dictionary.

    This method is used by Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Arguments:
      config: A Python dictionary, typically the output of get_config.

    Returns:
      A regularizer instance.
    """
    return cls(**config)

  def get_config(self):
    """Returns the config of the regularizer.

    A regularizer config is a Python dictionary (serializable)
    containing all configuration parameters of the regularizer.
    The same regularizer can be reinstantiated later
    (without any saved state) from this configuration.

    This method is optional if you are just training and executing models,
    exporting to and from SavedModels, or using weight checkpoints.

    This method is required for Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Returns:
      Python dictionary.
    """
    raise NotImplementedError(str(self) + ' does not implement get_config()')


class L1L2(Regularizer):
  r"""A regularizer that applies both L1 and L2 regularization penalties.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$

  The L2 regularization penalty is computed as:
  $$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$

  Attributes:
    l1: Float; L1 regularization factor.
    l2: Float; L2 regularization factor.
  """

  def __init__(self, l1=0., l2=0.):  # pylint: disable=redefined-outer-name
    self.l1 = K.cast_to_floatx(l1)
    self.l2 = K.cast_to_floatx(l2)

  def __call__(self, x):
    if not self.l1 and not self.l2:
      return K.constant(0.)
    regularization = 0.
    if self.l1:
      regularization += self.l1 * math_ops.reduce_sum(math_ops.abs(x))
    if self.l2:
      regularization += self.l2 * math_ops.reduce_sum(math_ops.square(x))
    return regularization

  def get_config(self):
    return {'l1': float(self.l1), 'l2': float(self.l2)}


# Aliases.


def l1(l=0.01):
  r"""Create a regularizer that applies an L1 regularization penalty.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$

  Arguments:
    l: Float; L1 regularization factor.

  Returns:
    An L1 Regularizer with the given regularization factor.
  """
  return L1L2(l1=l)


def l2(l=0.01):
  r"""Create a regularizer that applies an L2 regularization penalty.

  The L2 regularization penalty is computed as:
  $$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$

  Arguments:
    l: Float; L2 regularization factor.

  Returns:
    An L2 Regularizer with the given regularization factor.
  """
  return L1L2(l2=l)


def l1_l2(l1=0.01, l2=0.01):  # pylint: disable=redefined-outer-name
  r"""Create a regularizer that applies both L1 and L2 penalties.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|$$

  The L2 regularization penalty is computed as:
  $$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$

  Arguments:
    l1: Float; L1 regularization factor.
    l2: Float; L2 regularization factor.

  Returns:
    An L1L2 Regularizer with the given regularization factors.
  """
  return L1L2(l1=l1, l2=l2)


def serialize(regularizer):
  return serialize_keras_object(regularizer)


def deserialize(config, custom_objects=None):
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='regularizer')


def get(identifier):
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    identifier = str(identifier)
    # We have to special-case functions that return classes.
    # TODO(omalleyt): Turn these into classes or class aliases.
    special_cases = ['l1', 'l2', 'l1_l2']
    if identifier in special_cases:
      # Treat like a class.
      return deserialize({'class_name': identifier, 'config': {}})
    return deserialize(str(identifier))
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret regularizer identifier:', identifier)
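The `L1L2.__call__` penalty above is a plain sum over the whole tensor, which is easy to verify by hand. A short sketch (not part of this commit) against the public `tf.keras.regularizers.L1L2`, which this frozen copy mirrors:

```
import tensorflow as tf

# L1L2 computes l1 * sum(|x|) + l2 * sum(x**2) over the whole tensor.
reg = tf.keras.regularizers.L1L2(l1=0.01, l2=0.01)
x = tf.ones((5, 5))
print(float(reg(x)))  # 0.01 * 25 + 0.01 * 25 = 0.5
```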