# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


InputSpec = base_layer.InputSpec  # pylint: disable=invalid-name


@tf_export('layers.Layer')
class Layer(base_layer.Layer):
  """Base layer class.

  It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
  instead.

  Arguments:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: Default dtype of the layer's weights (default of `None` means use
      the type of the first input).

  Read-only properties:
    name: The name of the layer (string).
    dtype: Default dtype of the layer's weights (default of `None` means use
      the type of the first input).
    trainable_variables: List of trainable variables.
    non_trainable_variables: List of non-trainable variables.
    variables: List of all variables of this layer, trainable and
      non-trainable.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).

  Mutable properties:
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
  """
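
  # A minimal sketch of subclassing this legacy base (illustrative only, not
  # part of this module; `MyDense` and `units` are hypothetical names, and
  # `math_ops.matmul` assumes the usual TF math_ops import):
  #
  #   class MyDense(Layer):
  #
  #     def __init__(self, units, **kwargs):
  #       super(MyDense, self).__init__(**kwargs)
  #       self.units = units
  #
  #     def build(self, input_shape):
  #       self.kernel = self.add_weight(
  #           'kernel', shape=[input_shape[-1].value, self.units])
  #       self.built = True
  #
  #     def call(self, inputs):
  #       return math_ops.matmul(inputs, self.kernel)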
  def __init__(self, trainable=True, name=None, dtype=None,
               **kwargs):
    # For backwards compatibility, legacy layers do not use `ResourceVariable`
    # by default.
    self._use_resource_variables = False
    scope = kwargs.pop('_scope', None)
    self._reuse = kwargs.pop('_reuse', None)

    # Avoid an incorrect lint error
    self._trainable_weights = []
    self.built = False

    super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype,
                                **kwargs)

    self._graph = None
    self._call_has_scope_arg = 'scope' in self._call_fn_args
    if scope:
      with vs.variable_scope(scope) as captured_scope:
        self._scope = captured_scope
    else:
      self._scope = None
    self._current_scope = None

  @property
  def graph(self):
    if context.executing_eagerly():
      raise RuntimeError('Layer.graph not supported when executing eagerly.')
    return self._graph

  def _init_set_name(self, name):
    # Determine layer name (non-unique).
    if isinstance(name, vs.VariableScope):
      base_name = name.name
    else:
      base_name = name
    self._name = name
    if not name:
      self._name, base_name = self._make_unique_name()
    self._base_name = base_name

  def _make_unique_name(self, name_uid_map=None, avoid_names=None,
                        namespace='', zero_based=False):
    base_name = base_layer.to_snake_case(self.__class__.__name__)
    name = base_layer.unique_layer_name(base_name,
                                        name_uid_map=name_uid_map,
                                        avoid_names=avoid_names,
                                        namespace=namespace,
                                        zero_based=zero_based)
    return (name, base_name)

  @property
  def scope_name(self):
    if not self._scope:
      raise ValueError('No name available for layer scope because the layer "' +
                       self._name + '" has not been used yet. The scope name ' +
                       'is determined the first time the layer instance is ' +
                       'called. You must therefore call the layer before ' +
                       'querying `scope_name`.')
    return self._scope.name

  def add_loss(self, losses, inputs=None):
    previous_losses_length = len(self._losses)
    super(Layer, self).add_loss(losses, inputs=inputs)
    # TODO(fchollet): deprecate collection below.
    new_losses = self._losses[previous_losses_length:]
    _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)

  def _name_scope(self):
    """Determines op naming for the Layer."""
    return self._current_scope.original_name_scope

  def _set_scope(self, scope=None):
    if self._scope is None:
      # If constructed with _scope=None, lazy setting of scope.
      if self._reuse:
        with vs.variable_scope(
            scope if scope is not None else self._base_name) as captured_scope:
          self._scope = captured_scope
      else:
        with vs.variable_scope(
            scope, default_name=self._base_name) as captured_scope:
          self._scope = captured_scope
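
  # Scope-binding sketch (hypothetical usage; `_scope` is the private
  # constructor kwarg popped in `__init__`, and `MyDense` is the hypothetical
  # subclass from the class-level comment above):
  #
  #   layer = MyDense(10, _scope='shared')  # scope captured eagerly in __init__
  #   layer2 = MyDense(10)                  # scope resolved lazily on first use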
  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 use_resource=None,
                 synchronization=vs.VariableSynchronization.AUTO,
                 aggregation=vs.VariableAggregation.NONE,
                 partitioner=None):
    """Adds a new variable to the layer, or gets an existing one; returns it.

    Arguments:
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      regularizer: regularizer instance (callable).
      trainable: whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
        Note, if the current variable scope is marked as non-trainable
        then this parameter is ignored and any added variables are also
        marked as non-trainable. `trainable` defaults to `True` unless
        `synchronization` is set to `ON_READ`.
      constraint: constraint instance (callable).
      use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      partitioner: (optional) partitioner instance (callable). If
        provided, when the requested variable is created it will be split
        into multiple partitions according to `partitioner`. In this case,
        an instance of `PartitionedVariable` is returned. Available
        partitioners include `tf.fixed_size_partitioner` and
        `tf.variable_axis_size_partitioner`. For more details, see the
        documentation of `tf.get_variable` and the "Variable Partitioners
        and Sharding" section of the API guide.

    Returns:
      The created variable. Usually either a `Variable` or `ResourceVariable`
      instance. If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When trainable has been set to True with synchronization
        set as `ON_READ`.
    """
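    # Illustrative usage from a subclass's `build` (a sketch, not part of this
    # module; `kernel`, `bias`, `self.units` and the L2 regularizer are
    # hypothetical, and `nn_ops` assumes the usual TF nn_ops import):
    #
    #   def build(self, input_shape):
    #     self.kernel = self.add_weight(
    #         'kernel', shape=[input_shape[-1].value, self.units],
    #         regularizer=lambda w: 0.01 * nn_ops.l2_loss(w))
    #     self.bias = self.add_weight('bias', shape=[self.units])
    #     self.built = True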
    if synchronization == vs.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    def _should_add_regularizer(variable, existing_variable_set):
      if isinstance(variable, tf_variables.PartitionedVariable):
        for var in variable:
          if var in existing_variable_set:
            return False
        return True
      else:
        return variable not in existing_variable_set

    init_graph = None
    if not context.executing_eagerly():
      default_graph = ops.get_default_graph()
      if default_graph.building_function:
        with ops.init_scope():
          # Retrieve the variables from the graph into which variables
          # will be lifted; if initialization ops will be lifted into
          # the eager context, then there is nothing to retrieve, since variable
          # collections are not supported when eager execution is enabled.
          if not context.executing_eagerly():
            init_graph = ops.get_default_graph()
            existing_variables = set(tf_variables.global_variables())
      else:
        # Initialization ops will not be lifted out of the default graph.
        init_graph = default_graph
        existing_variables = set(tf_variables.global_variables())

    if dtype is None:
      dtype = self.dtype or dtypes.float32

    self._set_scope(None)
    reuse = self.built or self._reuse
    prev_len_trainable = len(self._trainable_weights)
    with vs.variable_scope(
        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
      self._current_scope = scope
      with ops.name_scope(self._name_scope()):
        use_resource = (use_resource or
                        self._use_resource_variables or
                        scope.use_resource)
        variable = super(Layer, self).add_weight(
            name,
            shape,
            dtype=dtypes.as_dtype(dtype),
            initializer=initializer or scope.initializer,
            trainable=trainable,
            constraint=constraint,
            partitioner=partitioner,
            use_resource=use_resource,
            synchronization=synchronization,
            aggregation=aggregation,
            getter=vs.get_variable)

        if regularizer:
          if context.executing_eagerly() or _should_add_regularizer(
              variable, existing_variables):
            self._handle_weight_regularization(name, variable, regularizer)

        if init_graph is not None:
          # Handle edge case where a custom getter has overridden `trainable`.
          # There is one known occurrence of this, in unit test
          # testBasicRNNCellNotTrainable in
          # contrib.rnn.python.kernel_tests.core_rnn_cell_test
          with init_graph.as_default():
            trainable_variables = tf_variables.trainable_variables()
          if (trainable and self.trainable and
              variable not in trainable_variables):
            # A custom getter / variable scope overrode the trainable flag.
            extra_trainable_vars = self._trainable_weights[prev_len_trainable:]
            self._trainable_weights = self._trainable_weights[
                :prev_len_trainable]
            self._non_trainable_weights += extra_trainable_vars
    return variable
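
  # Partitioning sketch (hypothetical; assumes the public `tf` namespace):
  # with a partitioner such as `tf.fixed_size_partitioner(2)`, `add_weight`
  # returns a `PartitionedVariable` split along axis 0, e.g.:
  #
  #   v = self.add_weight('w', shape=[4, 4],
  #                       partitioner=tf.fixed_size_partitioner(2))
  #   # v consists of 2 shards, each of shape [2, 4]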
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
    """
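    # Illustrative call pattern (a sketch; `MyDense` is the hypothetical
    # subclass from the class-level comment above and `x` is any tensor):
    #
    #   layer = MyDense(units=10)
    #   y = layer(x)    # scope is fixed and variables are created here
    #   y2 = layer(x)   # the same variables are reused (reuse=True)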
    self._set_scope(kwargs.pop('scope', None))

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = function_utils.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs

  def __deepcopy__(self, memo):
    no_copy = set(['_graph'])
    shallow_copy = set(['_scope', '_always_reuse_variable_scope'])
    cls = self.__class__
    result = cls.__new__(cls)
    memo[id(self)] = result
    for k, v in self.__dict__.items():
      if k in no_copy:
        setattr(result, k, v)
      elif k in shallow_copy:
        setattr(result, k, copy.copy(v))
      elif base_layer.is_tensor_or_tensor_list(v):
        setattr(result, k, v)
      else:
        setattr(result, k, copy.deepcopy(v, memo))
    return result


def _add_elements_to_collection(elements, collection_list):
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    collection_set = set(collection)
    for element in elements:
      if element not in collection_set:
        collection.append(element)
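

# Illustrative behavior of `_add_elements_to_collection` (a sketch; assumes a
# TF 1.x graph context and a hypothetical update op `update_op`):
#
#   _add_elements_to_collection(update_op, ops.GraphKeys.UPDATE_OPS)
#   _add_elements_to_collection(update_op, ops.GraphKeys.UPDATE_OPS)  # deduped
#   assert ops.get_collection(ops.GraphKeys.UPDATE_OPS).count(update_op) == 1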