Improve keras.constraints docstrings.
PiperOrigin-RevId: 302458247 Change-Id: I96722e54807ea3b8a4358b9552ec0233f8399707
This commit is contained in:
parent
506287350f
commit
00ad1522c8
@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
|
||||
@keras_export('keras.constraints.Constraint')
|
||||
@ -48,19 +49,21 @@ class MaxNorm(Constraint):
|
||||
Constrains the weights incident to each hidden unit
|
||||
to have a norm less than or equal to a desired value.
|
||||
|
||||
Also available via the shortcut function `tf.keras.constraints.max_norm`.
|
||||
|
||||
Arguments:
|
||||
m: the maximum norm for the incoming weights.
|
||||
axis: integer, axis along which to calculate weight norms.
|
||||
For instance, in a `Dense` layer the weight matrix
|
||||
has shape `(input_dim, output_dim)`,
|
||||
set `axis` to `0` to constrain each weight vector
|
||||
of length `(input_dim,)`.
|
||||
In a `Conv2D` layer with `data_format="channels_last"`,
|
||||
the weight tensor has shape
|
||||
`(rows, cols, input_depth, output_depth)`,
|
||||
set `axis` to `[0, 1, 2]`
|
||||
to constrain the weights of each filter tensor of size
|
||||
`(rows, cols, input_depth)`.
|
||||
max_value: the maximum norm value for the incoming weights.
|
||||
axis: integer, axis along which to calculate weight norms.
|
||||
For instance, in a `Dense` layer the weight matrix
|
||||
has shape `(input_dim, output_dim)`,
|
||||
set `axis` to `0` to constrain each weight vector
|
||||
of length `(input_dim,)`.
|
||||
In a `Conv2D` layer with `data_format="channels_last"`,
|
||||
the weight tensor has shape
|
||||
`(rows, cols, input_depth, output_depth)`,
|
||||
set `axis` to `[0, 1, 2]`
|
||||
to constrain the weights of each filter tensor of size
|
||||
`(rows, cols, input_depth)`.
|
||||
|
||||
"""
|
||||
|
||||
@ -68,12 +71,14 @@ class MaxNorm(Constraint):
|
||||
self.max_value = max_value
|
||||
self.axis = axis
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def __call__(self, w):
|
||||
norms = K.sqrt(
|
||||
math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
|
||||
desired = K.clip(norms, 0, self.max_value)
|
||||
return w * (desired / (K.epsilon() + norms))
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def get_config(self):
|
||||
return {'max_value': self.max_value, 'axis': self.axis}
|
||||
|
||||
@ -81,6 +86,8 @@ class MaxNorm(Constraint):
|
||||
@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
|
||||
class NonNeg(Constraint):
|
||||
"""Constrains the weights to be non-negative.
|
||||
|
||||
Also available via the shortcut function `tf.keras.constraints.non_neg`.
|
||||
"""
|
||||
|
||||
def __call__(self, w):
|
||||
@ -91,29 +98,33 @@ class NonNeg(Constraint):
|
||||
class UnitNorm(Constraint):
|
||||
"""Constrains the weights incident to each hidden unit to have unit norm.
|
||||
|
||||
Also available via the shortcut function `tf.keras.constraints.unit_norm`.
|
||||
|
||||
Arguments:
|
||||
axis: integer, axis along which to calculate weight norms.
|
||||
For instance, in a `Dense` layer the weight matrix
|
||||
has shape `(input_dim, output_dim)`,
|
||||
set `axis` to `0` to constrain each weight vector
|
||||
of length `(input_dim,)`.
|
||||
In a `Conv2D` layer with `data_format="channels_last"`,
|
||||
the weight tensor has shape
|
||||
`(rows, cols, input_depth, output_depth)`,
|
||||
set `axis` to `[0, 1, 2]`
|
||||
to constrain the weights of each filter tensor of size
|
||||
`(rows, cols, input_depth)`.
|
||||
axis: integer, axis along which to calculate weight norms.
|
||||
For instance, in a `Dense` layer the weight matrix
|
||||
has shape `(input_dim, output_dim)`,
|
||||
set `axis` to `0` to constrain each weight vector
|
||||
of length `(input_dim,)`.
|
||||
In a `Conv2D` layer with `data_format="channels_last"`,
|
||||
the weight tensor has shape
|
||||
`(rows, cols, input_depth, output_depth)`,
|
||||
set `axis` to `[0, 1, 2]`
|
||||
to constrain the weights of each filter tensor of size
|
||||
`(rows, cols, input_depth)`.
|
||||
"""
|
||||
|
||||
def __init__(self, axis=0):
|
||||
self.axis = axis
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def __call__(self, w):
|
||||
return w / (
|
||||
K.epsilon() + K.sqrt(
|
||||
math_ops.reduce_sum(
|
||||
math_ops.square(w), axis=self.axis, keepdims=True)))
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def get_config(self):
|
||||
return {'axis': self.axis}
|
||||
|
||||
@ -125,27 +136,29 @@ class MinMaxNorm(Constraint):
|
||||
Constrains the weights incident to each hidden unit
|
||||
to have the norm between a lower bound and an upper bound.
|
||||
|
||||
Also available via the shortcut function `tf.keras.constraints.min_max_norm`.
|
||||
|
||||
Arguments:
|
||||
min_value: the minimum norm for the incoming weights.
|
||||
max_value: the maximum norm for the incoming weights.
|
||||
rate: rate for enforcing the constraint: weights will be
|
||||
rescaled to yield
|
||||
`(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
|
||||
Effectively, this means that rate=1.0 stands for strict
|
||||
enforcement of the constraint, while rate<1.0 means that
|
||||
weights will be rescaled at each step to slowly move
|
||||
towards a value inside the desired interval.
|
||||
axis: integer, axis along which to calculate weight norms.
|
||||
For instance, in a `Dense` layer the weight matrix
|
||||
has shape `(input_dim, output_dim)`,
|
||||
set `axis` to `0` to constrain each weight vector
|
||||
of length `(input_dim,)`.
|
||||
In a `Conv2D` layer with `data_format="channels_last"`,
|
||||
the weight tensor has shape
|
||||
`(rows, cols, input_depth, output_depth)`,
|
||||
set `axis` to `[0, 1, 2]`
|
||||
to constrain the weights of each filter tensor of size
|
||||
`(rows, cols, input_depth)`.
|
||||
min_value: the minimum norm for the incoming weights.
|
||||
max_value: the maximum norm for the incoming weights.
|
||||
rate: rate for enforcing the constraint: weights will be
|
||||
rescaled to yield
|
||||
`(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
|
||||
Effectively, this means that rate=1.0 stands for strict
|
||||
enforcement of the constraint, while rate<1.0 means that
|
||||
weights will be rescaled at each step to slowly move
|
||||
towards a value inside the desired interval.
|
||||
axis: integer, axis along which to calculate weight norms.
|
||||
For instance, in a `Dense` layer the weight matrix
|
||||
has shape `(input_dim, output_dim)`,
|
||||
set `axis` to `0` to constrain each weight vector
|
||||
of length `(input_dim,)`.
|
||||
In a `Conv2D` layer with `data_format="channels_last"`,
|
||||
the weight tensor has shape
|
||||
`(rows, cols, input_depth, output_depth)`,
|
||||
set `axis` to `[0, 1, 2]`
|
||||
to constrain the weights of each filter tensor of size
|
||||
`(rows, cols, input_depth)`.
|
||||
"""
|
||||
|
||||
def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
|
||||
@ -154,6 +167,7 @@ class MinMaxNorm(Constraint):
|
||||
self.rate = rate
|
||||
self.axis = axis
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def __call__(self, w):
|
||||
norms = K.sqrt(
|
||||
math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
|
||||
@ -162,6 +176,7 @@ class MinMaxNorm(Constraint):
|
||||
(1 - self.rate) * norms)
|
||||
return w * (desired / (K.epsilon() + norms))
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def get_config(self):
|
||||
return {
|
||||
'min_value': self.min_value,
|
||||
@ -176,7 +191,10 @@ class MinMaxNorm(Constraint):
|
||||
class RadialConstraint(Constraint):
|
||||
"""Constrains `Conv2D` kernel weights to be the same for each radius.
|
||||
|
||||
For example, the desired output for the following 4-by-4 kernel::
|
||||
Also available via the shortcut function
|
||||
`tf.keras.constraints.radial_constraint`.
|
||||
|
||||
For example, the desired output for the following 4-by-4 kernel:
|
||||
|
||||
```
|
||||
kernel = [[v_00, v_01, v_02, v_03],
|
||||
@ -200,6 +218,7 @@ class RadialConstraint(Constraint):
|
||||
shape `(rows, cols, input_depth, output_depth)`.
|
||||
"""
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def __call__(self, w):
|
||||
w_shape = w.shape
|
||||
if w_shape.rank is None or w_shape.rank != 4:
|
||||
|
Loading…
Reference in New Issue
Block a user