From 2c6bfdd663fd602af2952b13bdb671358833a362 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Fri, 13 Mar 2020 21:06:01 -0700 Subject: [PATCH] Move tf.layers to keras folder and export from there. PiperOrigin-RevId: 300878712 Change-Id: I3648257c5e49f03a6d163110f4cfed10b57444ea --- tensorflow/python/BUILD | 157 +- .../python/keras/legacy_tf_layers/BUILD | 195 +++ .../python/keras/legacy_tf_layers/__init__.py | 0 .../python/keras/legacy_tf_layers/base.py | 593 +++++++ .../legacy_tf_layers}/base_test.py | 4 +- .../keras/legacy_tf_layers/convolutional.py | 1469 +++++++++++++++++ .../legacy_tf_layers}/convolutional_test.py | 2 +- .../python/keras/legacy_tf_layers/core.py | 338 ++++ .../legacy_tf_layers}/core_test.py | 2 +- .../keras/legacy_tf_layers/normalization.py | 342 ++++ .../legacy_tf_layers}/normalization_test.py | 4 +- .../python/keras/legacy_tf_layers/pooling.py | 470 ++++++ .../legacy_tf_layers}/pooling_test.py | 2 +- tensorflow/python/layers/base.py | 577 +------ tensorflow/python/layers/convolutional.py | 1449 +--------------- tensorflow/python/layers/core.py | 315 +--- tensorflow/python/layers/normalization.py | 317 +--- tensorflow/python/layers/pooling.py | 453 +---- ...ensorflow.layers.-average-pooling1-d.pbtxt | 4 +- ...ensorflow.layers.-average-pooling2-d.pbtxt | 4 +- ...ensorflow.layers.-average-pooling3-d.pbtxt | 4 +- ...nsorflow.layers.-batch-normalization.pbtxt | 4 +- .../v1/tensorflow.layers.-conv1-d.pbtxt | 4 +- ...tensorflow.layers.-conv2-d-transpose.pbtxt | 4 +- .../v1/tensorflow.layers.-conv2-d.pbtxt | 4 +- ...tensorflow.layers.-conv3-d-transpose.pbtxt | 4 +- .../v1/tensorflow.layers.-conv3-d.pbtxt | 4 +- .../golden/v1/tensorflow.layers.-dense.pbtxt | 4 +- .../v1/tensorflow.layers.-dropout.pbtxt | 4 +- .../v1/tensorflow.layers.-flatten.pbtxt | 4 +- .../golden/v1/tensorflow.layers.-layer.pbtxt | 2 +- .../tensorflow.layers.-max-pooling1-d.pbtxt | 4 +- .../tensorflow.layers.-max-pooling2-d.pbtxt | 4 +- .../tensorflow.layers.-max-pooling3-d.pbtxt | 4 +- ...tensorflow.layers.-separable-conv1-d.pbtxt | 4 +- ...tensorflow.layers.-separable-conv2-d.pbtxt | 4 +- ...perimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt | 2 +- ....experimental.nn.-tf-lite-r-n-n-cell.pbtxt | 2 +- ...flow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt | 2 +- ...orflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt | 2 +- ...nsorflow.nn.rnn_cell.-device-wrapper.pbtxt | 2 +- ...sorflow.nn.rnn_cell.-dropout-wrapper.pbtxt | 2 +- .../tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt | 2 +- ...tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt | 2 +- ...orflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt | 2 +- .../tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt | 2 +- ...orflow.nn.rnn_cell.-residual-wrapper.pbtxt | 2 +- 47 files changed, 3516 insertions(+), 3265 deletions(-) create mode 100644 tensorflow/python/keras/legacy_tf_layers/BUILD create mode 100644 tensorflow/python/keras/legacy_tf_layers/__init__.py create mode 100644 tensorflow/python/keras/legacy_tf_layers/base.py rename tensorflow/python/{layers => keras/legacy_tf_layers}/base_test.py (99%) create mode 100644 tensorflow/python/keras/legacy_tf_layers/convolutional.py rename tensorflow/python/{layers => keras/legacy_tf_layers}/convolutional_test.py (99%) create mode 100644 tensorflow/python/keras/legacy_tf_layers/core.py rename tensorflow/python/{layers => keras/legacy_tf_layers}/core_test.py (99%) create mode 100644 tensorflow/python/keras/legacy_tf_layers/normalization.py rename tensorflow/python/{layers => keras/legacy_tf_layers}/normalization_test.py (99%) create mode 100644 
tensorflow/python/keras/legacy_tf_layers/pooling.py rename tensorflow/python/{layers => keras/legacy_tf_layers}/pooling_test.py (99%) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index ff536c391bc..d932899ab0d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2062,8 +2062,8 @@ tf_py_test( main = "framework/registry_test.py", python_version = "PY3", deps = [ + ":client_testlib", ":framework_for_generated_wrappers", - "//tensorflow/python:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -3895,8 +3895,8 @@ cuda_py_test( srcs = ["training/experimental/mixed_precision_test.py"], python_version = "PY3", deps = [ + ":client_testlib", ":mixed_precision", - "//tensorflow/python:client_testlib", "@absl_py//absl/testing:parameterized", ], ) @@ -5879,7 +5879,7 @@ filegroup( "//tensorflow/core/util:port", # util_port "//tensorflow/core/util/tensor_bundle", # checkpoint_reader "//tensorflow/lite/toco/python:toco_python_api", # toco - "//tensorflow/python:tf_session_helper", # tf_session + ":tf_session_helper", # tf_session "//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib "//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer "//tensorflow/tools/graph_transforms:transform_graph_lib", # transform_graph @@ -7021,6 +7021,7 @@ py_tests( ], ) +# TODO(scottzhu): Move all the tf.layer related targets. py_library( name = "layers_base", srcs = [ @@ -7029,19 +7030,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":array_ops", - ":control_flow_ops", - ":framework_for_generated_wrappers", - ":layers_util", - ":platform", - ":smart_cond", - ":tensor_util", - ":util", - ":variable_scope", - ":variables", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras:engine", - "//third_party/py/numpy", + "//tensorflow/python/keras/legacy_tf_layers:layers_base", ], ) @@ -7070,134 +7059,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":array_ops", - ":array_ops_gen", - ":control_flow_ops", - ":framework", - ":framework_for_generated_wrappers", - ":init_ops", ":layers_base", - ":math_ops", - ":nn", - ":nn_ops", - ":platform", - ":resource_variable_ops", - ":resource_variable_ops_gen", - ":standard_ops", - ":state_ops", - ":training", - ":util", - ":variable_scope", - ":variables", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras/layers", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -tf_py_test( - name = "layers_base_test", - size = "small", - srcs = ["layers/base_test.py"], - main = "layers/base_test.py", - python_version = "PY3", - deps = [ - ":array_ops", - ":client_testlib", - ":framework_for_generated_wrappers", - ":framework_test_lib", - ":init_ops", - ":layers", - ":layers_base", - ":math_ops", - ":random_ops", - ":variable_scope", - "//tensorflow/python/eager:context", - ], -) - -tf_py_test( - name = "layers_core_test", - size = "small", - srcs = ["layers/core_test.py"], - main = "layers/core_test.py", - python_version = "PY3", - deps = [ - ":array_ops", - ":client_testlib", - ":framework_for_generated_wrappers", - ":framework_test_lib", - ":layers", - ":math_ops", - ":nn_ops", - ":random_ops", - ":variable_scope", - ":variables", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "layers_convolutional_test", - size = "small", - srcs = ["layers/convolutional_test.py"], - main = "layers/convolutional_test.py", - python_version = "PY3", - deps = [ - ":client_testlib", - ":framework_for_generated_wrappers", - ":framework_test_lib", - ":layers", - ":math_ops", - 
":nn_ops", - ":random_ops", - ], -) - -tf_py_test( - name = "layers_utils_test", - size = "small", - srcs = ["layers/utils_test.py"], - main = "layers/utils_test.py", - python_version = "PY3", - deps = [ - ":client_testlib", - ":layers", - ], -) - -tf_py_test( - name = "layers_pooling_test", - size = "small", - srcs = ["layers/pooling_test.py"], - main = "layers/pooling_test.py", - python_version = "PY3", - deps = [ - ":client_testlib", - ":framework_test_lib", - ":layers", - ":random_ops", - ], -) - -cuda_py_test( - name = "layers_normalization_test", - size = "medium", - srcs = ["layers/normalization_test.py"], - main = "layers/normalization_test.py", - python_version = "PY3", - shard_count = 10, - deps = [ - ":array_ops", - ":client_testlib", - ":framework_for_generated_wrappers", - ":framework_test_lib", - ":layers", - ":math_ops", - ":random_ops", - ":variables", - "//third_party/py/numpy", + "//tensorflow/python/keras/engine:input_spec", + "//tensorflow/python/keras/legacy_tf_layers:convolutional", + "//tensorflow/python/keras/legacy_tf_layers:core", + "//tensorflow/python/keras/legacy_tf_layers:normalization", + "//tensorflow/python/keras/legacy_tf_layers:pooling", ], ) @@ -7711,7 +7578,7 @@ tf_py_test( deps = [ ":client_testlib", ":graph_placer", - "//tensorflow/python:math_ops", + ":math_ops", ], ) @@ -8145,6 +8012,6 @@ cuda_py_test( srcs = ["ops/raw_ops_test.py"], python_version = "PY3", deps = [ - "//tensorflow/python:client_testlib", + ":client_testlib", ], ) diff --git a/tensorflow/python/keras/legacy_tf_layers/BUILD b/tensorflow/python/keras/legacy_tf_layers/BUILD new file mode 100644 index 00000000000..6789acc74cb --- /dev/null +++ b/tensorflow/python/keras/legacy_tf_layers/BUILD @@ -0,0 +1,195 @@ +# Description: +# Contains the legacy TF layers (internal TensorFlow version). 
+ +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +package( + default_visibility = ["//tensorflow:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "layers_base", + srcs = [ + "__init__.py", + "base.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/keras:backend", + "//tensorflow/python/keras/engine:base_layer", + "//tensorflow/python/keras/mixed_precision/experimental:policy", + "//tensorflow/python/training/tracking:base", + ], +) + +py_library( + name = "convolutional", + srcs = ["convolutional.py"], + deps = [ + ":layers_base", + "//tensorflow/python:init_ops", + "//tensorflow/python:util", + "//tensorflow/python/keras/layers", + ], +) + +py_library( + name = "core", + srcs = ["core.py"], + deps = [ + ":layers_base", + "//tensorflow/python:init_ops", + "//tensorflow/python:util", + "//tensorflow/python/keras/layers", + ], +) + +py_library( + name = "normalization", + srcs = ["normalization.py"], + deps = [ + ":layers_base", + "//tensorflow/python:init_ops", + "//tensorflow/python:util", + "//tensorflow/python/keras/layers:normalization", + ], +) + +py_library( + name = "pooling", + srcs = ["pooling.py"], + deps = [ + ":layers_base", + "//tensorflow/python:util", + "//tensorflow/python/keras/layers", + ], +) + +tf_py_test( + name = "base_test", + size = "small", + srcs = ["base_test.py"], + main = "base_test.py", + python_version = "PY3", + deps = [ + ":core", + ":layers_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/keras/engine:base_layer", + "//tensorflow/python/keras/engine:input_spec", + ], +) + +tf_py_test( + name = "core_test", + size = "small", + srcs = ["core_test.py"], + main = "core_test.py", + python_version = "PY3", + deps = [ + ":core", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +tf_py_test( + name = "convolutional_test", + size = "small", + srcs = ["convolutional_test.py"], + main = "convolutional_test.py", + python_version = "PY3", + deps = [ + ":convolutional", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_ops", + 
"//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "pooling_test", + size = "small", + srcs = ["pooling_test.py"], + main = "pooling_test.py", + python_version = "PY3", + deps = [ + ":pooling", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:random_ops", + ], +) + +cuda_py_test( + name = "normalization_test", + size = "medium", + srcs = ["normalization_test.py"], + main = "normalization_test.py", + python_version = "PY3", + shard_count = 10, + deps = [ + ":convolutional", + ":normalization", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:saver", + "//tensorflow/python:training_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) diff --git a/tensorflow/python/keras/legacy_tf_layers/__init__.py b/tensorflow/python/keras/legacy_tf_layers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/python/keras/legacy_tf_layers/base.py b/tensorflow/python/keras/legacy_tf_layers/base.py new file mode 100644 index 00000000000..5a944703af9 --- /dev/null +++ b/tensorflow/python/keras/legacy_tf_layers/base.py @@ -0,0 +1,593 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Contains the base Layer class, from which all layers inherit.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras import backend +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.mixed_precision.experimental import policy +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils +from tensorflow.python.util import nest +from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.tf_export import tf_export + +# Avoid breaking users who directly import this symbol from this file. 
+# TODO(fchollet): remove this. +InputSpec = base_layer.InputSpec # pylint: disable=invalid-name + +_KERAS_STYLE_SCOPE = False + + +@tf_export(v1=['layers.experimental.keras_style_scope']) +@tf_contextlib.contextmanager +def keras_style_scope(): + """Use Keras-style variable management. + + All tf.layers and tf RNN cells created in this scope use Keras-style + variable management. Creating such layers with a scope= argument is + disallowed, and reuse=True is disallowed. + + The purpose of this scope is to allow users of existing layers to + slowly transition to a Keras layers API without breaking existing + functionality. + + One example of this is when using TensorFlow's RNN classes with Keras + Models or Networks. Because Keras models do not properly set variable + scopes, users of RNNs may either accidentally share scopes between two + different models, or get errors about variables that already exist. + + Example: + + ```python + class RNNModel(tf.keras.Model): + + def __init__(self, name): + super(RNNModel, self).__init__(name=name) + self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell( + [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)]) + + def call(self, input, state): + return self.rnn(input, state) + + model_1 = RNNModel("model_1") + model_2 = RNNModel("model_2") + + # OK + output_1, next_state_1 = model_1(input, state) + # Raises an error about trying to create an already existing variable. + output_2, next_state_2 = model_2(input, state) + ``` + + The solution is to wrap the model construction and execution in a keras-style + scope: + + ```python + with keras_style_scope(): + model_1 = RNNModel("model_1") + model_2 = RNNModel("model_2") + + # model_1 and model_2 are guaranteed to create their own variables. + output_1, next_state_1 = model_1(input, state) + output_2, next_state_2 = model_2(input, state) + + assert len(model_1.weights) > 0 + assert len(model_2.weights) > 0 + assert(model_1.weights != model_2.weights) + ``` + + Yields: + A keras layer style scope. + """ + global _KERAS_STYLE_SCOPE + stack = _KERAS_STYLE_SCOPE + _KERAS_STYLE_SCOPE = True + try: + yield + finally: + _KERAS_STYLE_SCOPE = stack + + +@tf_export(v1=['layers.experimental.set_keras_style']) +def set_keras_style(): + """Use Keras-style variable management. + + All tf.layers and tf RNN cells created after keras style has been enabled + use Keras-style variable management. Creating such layers with a + scope= argument is disallowed, and reuse=True is disallowed. + + The purpose of this function is to allow users of existing layers to + slowly transition to a Keras layers API without breaking existing + functionality. + + For more details, see the documentation for `keras_style_scope`. + + Note, once keras style has been set, it is set globally for the entire + program and cannot be unset. + + Example: + + ```python + set_keras_style() + + model_1 = RNNModel(name="model_1") + model_2 = RNNModel(name="model_2") + + # model_1 and model_2 are guaranteed to create their own variables. + output_1, next_state_1 = model_1(input, state) + output_2, next_state_2 = model_2(input, state) + + assert len(model_1.weights) > 0 + assert len(model_2.weights) > 0 + assert(model_1.weights != model_2.weights) + ``` + """ + global _KERAS_STYLE_SCOPE + _KERAS_STYLE_SCOPE = True + + +def _is_in_keras_style_scope(): + global _KERAS_STYLE_SCOPE + return _KERAS_STYLE_SCOPE + + +@tf_export(v1=['layers.Layer']) +class Layer(base_layer.Layer): + """Base layer class. 
+ + It is considered legacy, and we recommend the use of `tf.keras.layers.Layer` + instead. + + Arguments: + trainable: Boolean, whether the layer's variables should be trainable. + name: String name of the layer. + dtype: Default dtype of the layer's weights (default of `None` means use the + type of the first input). + + Read-only properties: + name: The name of the layer (string). + dtype: Default dtype of the layer's weights (default of `None` means use the + type of the first input). + trainable_variables: List of trainable variables. + non_trainable_variables: List of non-trainable variables. + variables: List of all variables of this layer, trainable and + non-trainable. + updates: List of update ops of this layer. + losses: List of losses added by this layer. + trainable_weights: List of variables to be included in backprop. + non_trainable_weights: List of variables that should not be + included in backprop. + weights: The concatenation of the lists trainable_weights and + non_trainable_weights (in this order). + + Mutable properties: + trainable: Whether the layer should be trained (boolean). + input_spec: Optional (list of) `InputSpec` object(s) specifying the + constraints on inputs that can be accepted by the layer. + """ + + def __init__(self, trainable=True, name=None, dtype=None, + **kwargs): + # For backwards compatibility, legacy layers do not use `ResourceVariable` + # by default. + self._use_resource_variables = False + scope = kwargs.pop('_scope', None) + self._reuse = kwargs.pop('_reuse', None) + + # Avoid an incorrect lint error + self._trainable_weights = [] + self.built = False + + if dtype is None: + # Indicates to infer dtype from inputs. When the V2 dtype behavior is + # enabled, Keras layers default their dtype to floatx instead, so we pass + # an "_infer" policy to keep the old V1 behavior. + dtype = policy.Policy('_infer') + + if 'autocast' not in kwargs: + kwargs['autocast'] = False + + super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype, + **kwargs) + + if _is_in_keras_style_scope(): + if scope is not None: + raise ValueError( + 'scope argument not allowed when keras style layers are enabled, ' + 'but saw: {}'.format(scope)) + if self._reuse is not None: + raise ValueError( + 'reuse argument not allowed when keras style layers are enabled, ' + 'but saw: {}'.format(self._reuse)) + self._keras_style = True + else: + self._keras_style = False + + self._call_has_scope_arg = 'scope' in self._call_fn_args + if scope: + with vs.variable_scope(scope) as captured_scope: + self._scope = captured_scope + else: + self._scope = None + self._current_scope = None + + # We no longer track graph in tf.layers layers. This property is only kept to + # maintain API backward compatibility. + @property + @deprecation.deprecated( + date=None, + instructions='Stop using this property because tf.layers layers no ' + 'longer track their graph.') + def graph(self): + if context.executing_eagerly(): + raise RuntimeError('Layer.graph not supported when executing eagerly.') + return None + + def _init_set_name(self, name): + # Determine layer name (non-unique). 
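+    # Note (added comments): `name` may be a `VariableScope` (its name seeds a
+    # generated unique layer name), an explicit string (used verbatim), or
+    # None/empty (both the layer name and its base name are generated from
+    # the class name).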
+ if isinstance(name, vs.VariableScope): + base_name = name.name + self._name, _ = self._make_unique_name() + else: + base_name = name + self._name = name + if not name: + self._name, base_name = self._make_unique_name() + self._base_name = base_name + + def _make_unique_name(self, name_uid_map=None, avoid_names=None, + namespace='', zero_based=False): + base_name = base_layer.to_snake_case(self.__class__.__name__) + name = backend.unique_object_name( + base_name, + name_uid_map=name_uid_map, + avoid_names=avoid_names, + namespace=namespace, + zero_based=zero_based) + return (name, base_name) + + @property + def scope_name(self): + if not self._scope: + raise ValueError('No name available for layer scope because the layer "' + + self._name + '" has not been used yet. The scope name ' + + 'is determined the first time the layer instance is ' + + 'called. You must therefore call the layer before ' + + 'querying `scope_name`.') + return self._scope.name + + def add_loss(self, losses, inputs=None): + previous_losses_length = len(self._losses) + previous_callable_losses_length = len(self._callable_losses) + super(Layer, self).add_loss(losses, inputs=inputs) + if not context.executing_eagerly(): + # TODO(fchollet): deprecate collection below. + new_losses = self._losses[previous_losses_length:] + new_callable_losses = self._callable_losses[ + previous_callable_losses_length:] + for regularizer in new_callable_losses: + loss_tensor = regularizer() + if loss_tensor is not None: + new_losses.append(loss_tensor) + _add_elements_to_collection( + new_losses, + ops.GraphKeys.REGULARIZATION_LOSSES) + + def _name_scope(self): + """Determines op naming for the Layer.""" + if self._keras_style: + return super(Layer, self)._name_scope() + return self._current_scope.original_name_scope + + def _set_scope(self, scope=None): + if self._scope is None: + # If constructed with _scope=None, lazy setting of scope. + if self._reuse: + with vs.variable_scope( + scope if scope is not None else self._base_name) as captured_scope: + self._scope = captured_scope + else: + with vs.variable_scope( + scope, default_name=self._base_name) as captured_scope: + self._scope = captured_scope + + def add_weight(self, + name, + shape, + dtype=None, + initializer=None, + regularizer=None, + trainable=None, + constraint=None, + use_resource=None, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, + partitioner=None, + **kwargs): + """Adds a new variable to the layer, or gets an existing one; returns it. + + Arguments: + name: variable name. + shape: variable shape. + dtype: The type of the variable. Defaults to `self.dtype` or `float32`. + initializer: initializer instance (callable). + regularizer: regularizer instance (callable). + trainable: whether the variable should be part of the layer's + "trainable_variables" (e.g. variables, biases) + or "non_trainable_variables" (e.g. BatchNorm mean, stddev). + Note, if the current variable scope is marked as non-trainable + then this parameter is ignored and any added variables are also + marked as non-trainable. `trainable` defaults to `True` unless + `synchronization` is set to `ON_READ`. + constraint: constraint instance (callable). + use_resource: Whether to use `ResourceVariable`. + synchronization: Indicates when a distributed variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. 
By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. + partitioner: (optional) partitioner instance (callable). If + provided, when the requested variable is created it will be split + into multiple partitions according to `partitioner`. In this case, + an instance of `PartitionedVariable` is returned. Available + partitioners include `tf.compat.v1.fixed_size_partitioner` and + `tf.compat.v1.variable_axis_size_partitioner`. For more details, see + the documentation of `tf.compat.v1.get_variable` and the "Variable + Partitioners and Sharding" section of the API guide. + **kwargs: Additional keyword arguments. + + Returns: + The created variable. Usually either a `Variable` or `ResourceVariable` + instance. If `partitioner` is not `None`, a `PartitionedVariable` + instance is returned. + + Raises: + RuntimeError: If called with partitioned variable regularization and + eager execution is enabled. + ValueError: When trainable has been set to True with synchronization + set as `ON_READ`. + """ + for kwarg in kwargs: + if kwarg != 'experimental_autocast': + raise TypeError('Unknown keyword argument:', kwarg) + if self._keras_style: + return super(Layer, self).add_weight( + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable and self.trainable, + constraint=constraint, + use_resource=use_resource, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, + partitioner=partitioner, + **kwargs) + + if synchronization == vs.VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + 'Synchronization value can be set to ' + 'VariableSynchronization.ON_READ only for non-trainable variables. ' + 'You have specified trainable=True and ' + 'synchronization=VariableSynchronization.ON_READ.') + else: + # Set trainable to be false when variable is to be synced on read. + trainable = False + elif trainable is None: + trainable = True + + def _should_add_regularizer(variable, existing_variable_set): + if isinstance(variable, tf_variables.PartitionedVariable): + for var in variable: + if var in existing_variable_set: + return False + return True + else: + return variable not in existing_variable_set + + init_graph = None + if not context.executing_eagerly(): + default_graph = ops.get_default_graph() + if default_graph.building_function: + with ops.init_scope(): + # Retrieve the variables from the graph into which variables + # will be lifted; if initialization ops will be lifted into + # the eager context, then there is nothing to retrieve, since variable + # collections are not supported when eager execution is enabled. + if not context.executing_eagerly(): + init_graph = ops.get_default_graph() + existing_variables = set(tf_variables.global_variables()) + else: + # Initialization ops will not be lifted out of the default graph. 
+ init_graph = default_graph + existing_variables = set(tf_variables.global_variables()) + + if dtype is None: + dtype = self.dtype or dtypes.float32 + + self._set_scope(None) + reuse = self.built or self._reuse + prev_len_trainable = len(self._trainable_weights) + with vs.variable_scope( + self._scope, reuse=reuse, auxiliary_name_scope=False) as scope: + self._current_scope = scope + with ops.name_scope(self._name_scope(), skip_on_eager=False): + use_resource = (use_resource or + self._use_resource_variables or + scope.use_resource) + if initializer is None: + initializer = scope.initializer + variable = super(Layer, self).add_weight( + name, + shape, + dtype=dtypes.as_dtype(dtype), + initializer=initializer, + trainable=trainable and self.trainable, + constraint=constraint, + partitioner=partitioner, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation, + getter=vs.get_variable, + **kwargs) + + if regularizer: + if (ops.executing_eagerly_outside_functions() + or _should_add_regularizer(variable, existing_variables)): + self._handle_weight_regularization(name, variable, regularizer) + + if init_graph is not None: + # Handle edge case where a custom getter has overridden `trainable`. + # There is one known occurrence of this, in unit test + # testBasicRNNCellNotTrainable in + # contrib.rnn.python.kernel_tests.core_rnn_cell_test + with init_graph.as_default(): + trainable_variables = tf_variables.trainable_variables() + if (trainable and self.trainable and + variable not in trainable_variables): + # A custom getter / variable scope overrode the trainable flag. + extra_trainable_vars = self._trainable_weights[prev_len_trainable:] + self._trainable_weights = self._trainable_weights[ + :prev_len_trainable] + self._non_trainable_weights += extra_trainable_vars + return variable + + def __call__(self, inputs, *args, **kwargs): + """Wraps `call`, applying pre- and post-processing steps. + + Arguments: + inputs: input tensor(s). + *args: additional positional arguments to be passed to `self.call`. + **kwargs: additional keyword arguments to be passed to `self.call`. + **Note**: kwarg `scope` is reserved for use by the layer. + + Returns: + Output tensor(s). + + Note: + - If the layer's `call` method takes a `scope` keyword argument, + this argument will be automatically set to the current variable scope. + - If the layer's `call` method takes a `mask` argument (as some Keras + layers do), its default value will be set to the mask generated + for `inputs` by the previous layer (if `inputs` did come from + a layer that generated a corresponding mask, i.e. if it came from + a Keras layer with masking support). + + Raises: + ValueError: if the layer's `call` method returns None (an invalid value). + """ + scope = kwargs.pop('scope', None) + + if self._keras_style: + if scope is not None: + raise ValueError( + 'scope argument not allowed when keras style layers are enabled, ' + 'but saw: {}'.format(scope)) + return super(Layer, self).__call__(inputs, *args, **kwargs) + + self._set_scope(scope) + + if self.built: + try: + # Some classes which inherit from Layer do not use its constructor, so + # rather than initializing to None we check for an AttributeError. + scope_context_manager = self._always_reuse_variable_scope + except AttributeError: + # From this point we will always set reuse=True, so create a "final" + # variable scope with this setting. We avoid re-creating variable scopes + # after this point as an optimization. 
+ self._always_reuse_variable_scope = vs.variable_scope( + self._scope, reuse=True, auxiliary_name_scope=False) + scope_context_manager = self._always_reuse_variable_scope + else: + scope_context_manager = vs.variable_scope( + self._scope, reuse=self._reuse, auxiliary_name_scope=False) + + with scope_context_manager as scope: + self._current_scope = scope + + try: + call_has_scope_arg = self._call_has_scope_arg + except AttributeError: + self._call_fn_args = function_utils.fn_args(self.call) + self._call_has_scope_arg = 'scope' in self._call_fn_args + call_has_scope_arg = self._call_has_scope_arg + if call_has_scope_arg: + kwargs['scope'] = scope + + # Actually call layer + outputs = super(Layer, self).__call__(inputs, *args, **kwargs) + + if not context.executing_eagerly(): + # Update global default collections. + _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) + return outputs + + def __deepcopy__(self, memo): + no_copy = set(['_graph', '_thread_local', '_metrics_lock']) + shallow_copy = set(['_scope', '_always_reuse_variable_scope']) + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k in no_copy: + setattr(result, k, v) + elif k in shallow_copy: + setattr(result, k, copy.copy(v)) + elif base_layer.is_tensor_or_tensor_list(v): + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + + def __setattr__(self, name, value): + # By-pass the automatic dependency tracking performed by the parent Layer. + super(trackable.Trackable, self).__setattr__(name, value) + + @property + def _is_legacy_layer(self): + """Used by keras to check compatibility. This should not be overridden.""" + return True + + +def _add_elements_to_collection(elements, collection_list): + if context.executing_eagerly(): + raise RuntimeError('Using collections from Layers not supported in Eager ' + 'mode. 
Tried to add %s to %s' % (elements, + collection_list)) + elements = nest.flatten(elements) + collection_list = nest.flatten(collection_list) + for name in collection_list: + collection = ops.get_collection_ref(name) + collection_set = {id(e) for e in collection} + for element in elements: + if id(element) not in collection_set: + collection.append(element) diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/keras/legacy_tf_layers/base_test.py similarity index 99% rename from tensorflow/python/layers/base_test.py rename to tensorflow/python/keras/legacy_tf_layers/base_test.py index 321a1854819..106517623a8 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/keras/legacy_tf_layers/base_test.py @@ -30,8 +30,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.keras.engine import input_spec -from tensorflow.python.layers import base as base_layers -from tensorflow.python.layers import core as core_layers +from tensorflow.python.keras.legacy_tf_layers import base as base_layers +from tensorflow.python.keras.legacy_tf_layers import core as core_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/keras/legacy_tf_layers/convolutional.py b/tensorflow/python/keras/legacy_tf_layers/convolutional.py new file mode 100644 index 00000000000..4c91251a0e7 --- /dev/null +++ b/tensorflow/python/keras/legacy_tf_layers/convolutional.py @@ -0,0 +1,1469 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Contains the convolutional layer classes and their functional aliases. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras import layers as keras_layers +from tensorflow.python.keras.legacy_tf_layers import base +from tensorflow.python.ops import init_ops +from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export + + +@tf_export(v1=['layers.Conv1D']) +class Conv1D(keras_layers.Conv1D, base.Layer): + """1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. 
+ strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. 
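+
+  Example (a minimal usage sketch, assuming a TF1 graph-mode program; the
+  shape in the comment follows from the default `'valid'` padding):
+
+  ```python
+  x = tf.compat.v1.placeholder(tf.float32, shape=(None, 100, 8))
+  conv = tf.compat.v1.layers.Conv1D(filters=16, kernel_size=3,
+                                    activation=tf.nn.relu)
+  y = conv(x)  # shape: (None, 98, 16)
+  ```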
+ """ + + def __init__(self, filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(Conv1D, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, **kwargs) + + +@deprecation.deprecated( + date=None, + instructions='Use `tf.keras.layers.Conv1D` instead.') +@tf_export(v1=['layers.conv1d']) +def conv1d(inputs, + filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for 1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. 
+ kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = Conv1D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.Conv2D']) +class Conv2D(keras_layers.Conv2D, base.Layer): + """2D convolution layer (e.g. spatial convolution over images). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. 
If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1), + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(Conv2D, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, **kwargs) + + +@deprecation.deprecated( + date=None, + instructions='Use `tf.keras.layers.Conv2D` instead.') +@tf_export(v1=['layers.conv2d']) +def conv2d(inputs, + filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1), + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for the 2D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. 
+ padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = Conv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.Conv3D']) +class Conv3D(keras_layers.Conv3D, base.Layer): + """3D convolution layer (e.g. spatial convolution over volumes). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. 
+ strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. 
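+
+  Example (a minimal usage sketch, assuming a TF1 graph-mode program; with
+  `'same'` padding the spatial dimensions are preserved):
+
+  ```python
+  volumes = tf.compat.v1.placeholder(tf.float32,
+                                     shape=(None, 16, 64, 64, 1))
+  conv = tf.compat.v1.layers.Conv3D(filters=8, kernel_size=3, padding='same')
+  outputs = conv(volumes)  # shape: (None, 16, 64, 64, 8)
+  ```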
+ """ + + def __init__(self, filters, + kernel_size, + strides=(1, 1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1, 1), + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(Conv3D, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, **kwargs) + + +@deprecation.deprecated( + date=None, + instructions='Use `tf.keras.layers.Conv3D` instead.') +@tf_export(v1=['layers.conv3d']) +def conv3d(inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1, 1), + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for the 3D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. 
+ bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = Conv3D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.SeparableConv1D']) +class SeparableConv1D(keras_layers.SeparableConv1D, base.Layer): + """Depthwise separable 1D convolution. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: A single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. 
+ depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer=None, + pointwise_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(SeparableConv1D, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) + + +@tf_export(v1=['layers.SeparableConv2D']) +class SeparableConv2D(keras_layers.SeparableConv2D, base.Layer): + """Depthwise separable 2D convolution. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A tuple or list of 2 integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. 
+ strides: A tuple or list of 2 positive integers specifying the strides + of the convolution. Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. 
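+
+  Example (illustrative; the input shape below is assumed):
+
+  ```python
+  # A batch of 4 RGB images, `channels_last`.
+  x = tf.compat.v1.placeholder(shape=(4, 32, 32, 3), dtype='float32')
+  y = SeparableConv2D(filters=16, kernel_size=3, padding='same')(x)
+  # now `y` has shape `(4, 32, 32, 16)`
+  ```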
+ """ + + def __init__(self, filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer=None, + pointwise_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(SeparableConv2D, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) + + +@deprecation.deprecated( + date=None, + instructions='Use `tf.keras.layers.SeparableConv1D` instead.') +@tf_export(v1=['layers.separable_conv1d']) +def separable_conv1d(inputs, + filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer=None, + pointwise_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for the depthwise separable 1D convolution layer. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + inputs: Input tensor. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: A single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. 
The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = SeparableConv1D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + +@deprecation.deprecated( + date=None, + instructions='Use `tf.keras.layers.SeparableConv2D` instead.') +@tf_export(v1=['layers.separable_conv2d']) +def separable_conv2d(inputs, + filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer=None, + pointwise_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for the depthwise separable 2D convolution layer. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. 
+ If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. + It then optionally applies an activation function to produce the final output. + + Arguments: + inputs: Input tensor. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A tuple or list of 2 integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + strides: A tuple or list of 2 positive integers specifying the strides + of the convolution. Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + depth_multiplier: The number of depthwise convolution output channels for + each input channel. The total number of depthwise convolution output + channels will be equal to `num_filters_in * depth_multiplier`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + depthwise_initializer: An initializer for the depthwise convolution kernel. + pointwise_initializer: An initializer for the pointwise convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used for + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. 
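+
+  Example (an illustrative sketch in graph mode; the input shape is assumed):
+
+  ```python
+  x = tf.compat.v1.placeholder(shape=(4, 32, 32, 3), dtype='float32')
+  y = tf.compat.v1.layers.separable_conv2d(x, filters=16, kernel_size=3,
+                                           padding='same')
+  # now `y` has shape `(4, 32, 32, 16)`
+  ```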
+ """ + layer = SeparableConv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.Conv2DTranspose']) +class Conv2DTranspose(keras_layers.Conv2DTranspose, base.Layer): + """Transposed 2D convolution layer (sometimes called 2D Deconvolution). + + The need for transposed convolutions generally arises + from the desire to use a transformation going in the opposite direction + of a normal convolution, i.e., from something that has the shape of the + output of some convolution to something that has the shape of its input + while maintaining a connectivity pattern that is compatible with + said convolution. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A tuple or list of 2 positive integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + strides: A tuple or list of 2 positive integers specifying the strides + of the convolution. Can be a single integer to specify the same value for + all spatial dimensions. + padding: one of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. 
+ """ + + def __init__(self, filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(Conv2DTranspose, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) + + +@deprecation.deprecated( + date=None, + instructions='Use `tf.keras.layers.Conv2DTranspose` instead.') +@tf_export(v1=['layers.conv2d_transpose']) +def conv2d_transpose(inputs, + filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for transposed 2D convolution layer. + + The need for transposed convolutions generally arises + from the desire to use a transformation going in the opposite direction + of a normal convolution, i.e., from something that has the shape of the + output of some convolution to something that has the shape of its input + while maintaining a connectivity pattern that is compatible with + said convolution. + + Arguments: + inputs: Input tensor. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A tuple or list of 2 positive integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + strides: A tuple or list of 2 positive integers specifying the strides + of the convolution. Can be a single integer to specify the same value for + all spatial dimensions. + padding: one of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + activation: Activation function. Set it to `None` to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If `None`, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). 
The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = Conv2DTranspose( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.Conv3DTranspose']) +class Conv3DTranspose(keras_layers.Conv3DTranspose, base.Layer): + """Transposed 3D convolution layer (sometimes called 3D Deconvolution). + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for all spatial + dimensions. + strides: An integer or tuple/list of 3 integers, specifying the strides + of the convolution along the depth, height and width. + Can be a single integer to specify the same value for all spatial + dimensions. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + activation: Activation function. Set it to `None` to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If `None`, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 
+ name: A string, the name of the layer. + """ + + def __init__(self, + filters, + kernel_size, + strides=(1, 1, 1), + padding='valid', + data_format='channels_last', + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(Conv3DTranspose, self).__init__( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) + + +@deprecation.deprecated( + date=None, + instructions='Use `tf.keras.layers.Conv3DTranspose` instead.') +@tf_export(v1=['layers.conv3d_transpose']) +def conv3d_transpose(inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding='valid', + data_format='channels_last', + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for transposed 3D convolution layer. + + Arguments: + inputs: Input tensor. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: A tuple or list of 3 positive integers specifying the spatial + dimensions of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + strides: A tuple or list of 3 positive integers specifying the strides + of the convolution. Can be a single integer to specify the same value for + all spatial dimensions. + padding: one of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, the default + initializer will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. 
+ trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = Conv3DTranspose( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs) + + +# Aliases + +Convolution1D = Conv1D +Convolution2D = Conv2D +Convolution3D = Conv3D +SeparableConvolution2D = SeparableConv2D +Convolution2DTranspose = Deconvolution2D = Deconv2D = Conv2DTranspose +Convolution3DTranspose = Deconvolution3D = Deconv3D = Conv3DTranspose +convolution1d = conv1d +convolution2d = conv2d +convolution3d = conv3d +separable_convolution2d = separable_conv2d +convolution2d_transpose = deconvolution2d = deconv2d = conv2d_transpose +convolution3d_transpose = deconvolution3d = deconv3d = conv3d_transpose diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/keras/legacy_tf_layers/convolutional_test.py similarity index 99% rename from tensorflow/python/layers/convolutional_test.py rename to tensorflow/python/keras/legacy_tf_layers/convolutional_test.py index a3e493edfea..b0eeede8737 100644 --- a/tensorflow/python/layers/convolutional_test.py +++ b/tensorflow/python/keras/legacy_tf_layers/convolutional_test.py @@ -23,7 +23,7 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.layers import convolutional as conv_layers +from tensorflow.python.keras.legacy_tf_layers import convolutional as conv_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/keras/legacy_tf_layers/core.py b/tensorflow/python/keras/legacy_tf_layers/core.py new file mode 100644 index 00000000000..78ddf2547ae --- /dev/null +++ b/tensorflow/python/keras/legacy_tf_layers/core.py @@ -0,0 +1,338 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Contains the core layers: Dense, Dropout. + +Also contains their functional aliases. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.keras import layers as keras_layers +from tensorflow.python.keras.legacy_tf_layers import base +from tensorflow.python.ops import init_ops +from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export + + +@tf_export(v1=['layers.Dense']) +class Dense(keras_layers.Dense, base.Layer): + """Densely-connected layer class. + + This layer implements the operation: + `outputs = activation(inputs * kernel + bias)` + Where `activation` is the activation function passed as the `activation` + argument (if not `None`), `kernel` is a weights matrix created by the layer, + and `bias` is a bias vector created by the layer + (only if `use_bias` is `True`). + + Arguments: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (callable). Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the default + initializer used by `tf.compat.v1.get_variable`. + bias_initializer: Initializer function for the bias. + kernel_regularizer: Regularizer function for the weight matrix. + bias_regularizer: Regularizer function for the bias. + activity_regularizer: Regularizer function for the output. + kernel_constraint: An optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: An optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such cases. + _reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (callable). + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: Initializer instance (or name) for the kernel matrix. + bias_initializer: Initializer instance (or name) for the bias. + kernel_regularizer: Regularizer instance for the kernel matrix (callable) + bias_regularizer: Regularizer instance for the bias (callable). + activity_regularizer: Regularizer instance for the output (callable) + kernel_constraint: Constraint function for the kernel matrix. + bias_constraint: Constraint function for the bias. + kernel: Weight matrix (TensorFlow variable or tensor). + bias: Bias vector, if applicable (TensorFlow variable or tensor). 
+ """ + + def __init__(self, units, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs): + super(Dense, self).__init__(units=units, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + **kwargs) + + +@deprecation.deprecated( + date=None, instructions='Use keras.layers.Dense instead.') +@tf_export(v1=['layers.dense']) +def dense( + inputs, units, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + reuse=None): + """Functional interface for the densely-connected layer. + + This layer implements the operation: + `outputs = activation(inputs * kernel + bias)` + where `activation` is the activation function passed as the `activation` + argument (if not `None`), `kernel` is a weights matrix created by the layer, + and `bias` is a bias vector created by the layer + (only if `use_bias` is `True`). + + Arguments: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (callable). Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the default + initializer used by `tf.compat.v1.get_variable`. + bias_initializer: Initializer function for the bias. + kernel_regularizer: Regularizer function for the weight matrix. + bias_regularizer: Regularizer function for the bias. + activity_regularizer: Regularizer function for the output. + kernel_constraint: An optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The function + must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are + not safe to use when doing asynchronous distributed training. + bias_constraint: An optional projection function to be applied to the + bias after being updated by an `Optimizer`. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: String, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor the same shape as `inputs` except the last dimension is of + size `units`. + + Raises: + ValueError: if eager execution is enabled. 
+ """ + layer = Dense(units, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + trainable=trainable, + name=name, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +@tf_export(v1=['layers.Dropout']) +class Dropout(keras_layers.Dropout, base.Layer): + """Applies Dropout to the input. + + Dropout consists in randomly setting a fraction `rate` of input units to 0 + at each update during training time, which helps prevent overfitting. + The units that are kept are scaled by `1 / (1 - rate)`, so that their + sum is unchanged at training time and inference time. + + Arguments: + rate: The dropout rate, between 0 and 1. E.g. `rate=0.1` would drop out + 10% of input units. + noise_shape: 1D tensor of type `int32` representing the shape of the + binary dropout mask that will be multiplied with the input. + For instance, if your inputs have shape + `(batch_size, timesteps, features)`, and you want the dropout mask + to be the same for all timesteps, you can use + `noise_shape=[batch_size, 1, features]`. + seed: A Python integer. Used to create random seeds. See + `tf.compat.v1.set_random_seed`. + for behavior. + name: The name of the layer (string). + """ + + def __init__(self, rate=0.5, + noise_shape=None, + seed=None, + name=None, + **kwargs): + super(Dropout, self).__init__(rate=rate, + noise_shape=noise_shape, + seed=seed, + name=name, + **kwargs) + + def call(self, inputs, training=False): + return super(Dropout, self).call(inputs, training=training) + + +@deprecation.deprecated( + date=None, + instructions='Use keras.layers.dropout instead.') +@tf_export(v1=['layers.dropout']) +def dropout(inputs, + rate=0.5, + noise_shape=None, + seed=None, + training=False, + name=None): + """Applies Dropout to the input. + + Dropout consists in randomly setting a fraction `rate` of input units to 0 + at each update during training time, which helps prevent overfitting. + The units that are kept are scaled by `1 / (1 - rate)`, so that their + sum is unchanged at training time and inference time. + + Arguments: + inputs: Tensor input. + rate: The dropout rate, between 0 and 1. E.g. "rate=0.1" would drop out + 10% of input units. + noise_shape: 1D tensor of type `int32` representing the shape of the + binary dropout mask that will be multiplied with the input. + For instance, if your inputs have shape + `(batch_size, timesteps, features)`, and you want the dropout mask + to be the same for all timesteps, you can use + `noise_shape=[batch_size, 1, features]`. + seed: A Python integer. Used to create random seeds. See + `tf.compat.v1.set_random_seed` + for behavior. + training: Either a Python boolean, or a TensorFlow boolean scalar tensor + (e.g. a placeholder). Whether to return the output in training mode + (apply dropout) or in inference mode (return the input untouched). + name: The name of the layer (string). + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = Dropout(rate, noise_shape=noise_shape, seed=seed, name=name) + return layer.apply(inputs, training=training) + + +@tf_export(v1=['layers.Flatten']) +class Flatten(keras_layers.Flatten, base.Layer): + """Flattens an input tensor while preserving the batch axis (axis 0). 
+ + Arguments: + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + + Examples: + + ``` + x = tf.compat.v1.placeholder(shape=(None, 4, 4), dtype='float32') + y = Flatten()(x) + # now `y` has shape `(None, 16)` + + x = tf.compat.v1.placeholder(shape=(None, 3, None), dtype='float32') + y = Flatten()(x) + # now `y` has shape `(None, None)` + ``` + """ + pass + + +@deprecation.deprecated( + date=None, + instructions='Use keras.layers.Flatten instead.') +@tf_export(v1=['layers.flatten']) +def flatten(inputs, name=None, data_format='channels_last'): + """Flattens an input tensor while preserving the batch axis (axis 0). + + Arguments: + inputs: Tensor input. + name: The name of the layer (string). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + Returns: + Reshaped tensor. + + Examples: + + ``` + x = tf.compat.v1.placeholder(shape=(None, 4, 4), dtype='float32') + y = flatten(x) + # now `y` has shape `(None, 16)` + + x = tf.compat.v1.placeholder(shape=(None, 3, None), dtype='float32') + y = flatten(x) + # now `y` has shape `(None, None)` + ``` + """ + layer = Flatten(name=name, data_format=data_format) + return layer.apply(inputs) + + +# Aliases + +FullyConnected = Dense +fully_connected = dense diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/keras/legacy_tf_layers/core_test.py similarity index 99% rename from tensorflow/python/layers/core_test.py rename to tensorflow/python/keras/legacy_tf_layers/core_test.py index afd288739b6..732c2558920 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/keras/legacy_tf_layers/core_test.py @@ -29,7 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util -from tensorflow.python.layers import core as core_layers +from tensorflow.python.keras.legacy_tf_layers import core as core_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/keras/legacy_tf_layers/normalization.py b/tensorflow/python/keras/legacy_tf_layers/normalization.py new file mode 100644 index 00000000000..d874882aed1 --- /dev/null +++ b/tensorflow/python/keras/legacy_tf_layers/normalization.py @@ -0,0 +1,342 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =============================================================================
+
+"""Contains the normalization layer classes and their functional aliases.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.keras.layers import normalization as keras_normalization
+from tensorflow.python.keras.legacy_tf_layers import base
+from tensorflow.python.ops import init_ops
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export(v1=['layers.BatchNormalization'])
+class BatchNormalization(keras_normalization.BatchNormalization, base.Layer):
+  """Batch Normalization layer from (Ioffe et al., 2015).
+
+  Keras APIs handle BatchNormalization updates to the moving_mean and
+  moving_variance as part of their `fit()` and `evaluate()` loops. However, if a
+  custom training loop is used with an instance of `Model`, these updates need
+  to be explicitly included. Here's a simple example of how it can be done:
+
+  ```python
+  # model is an instance of Model that contains a BatchNormalization layer.
+  update_ops = model.get_updates_for(None) + model.get_updates_for(features)
+  train_op = optimizer.minimize(loss)
+  train_op = tf.group([train_op, update_ops])
+  ```
+
+  Arguments:
+    axis: An `int` or list of `int`, the axis or axes that should be normalized,
+      typically the features axis/axes. For instance, after a `Conv2D` layer
+      with `data_format="channels_first"`, set `axis=1`. If a list of axes is
+      provided, each axis in `axis` will be normalized
+      simultaneously. Default is `-1` which uses the last axis. Note: when
+      using multi-axis batch norm, the `beta`, `gamma`, `moving_mean`, and
+      `moving_variance` variables are the same rank as the input Tensor,
+      with dimension size 1 in all reduced (non-axis) dimensions.
+    momentum: Momentum for the moving average.
+    epsilon: Small float added to variance to avoid dividing by zero.
+    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+      is ignored.
+    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
+      next layer is linear (also e.g. `nn.relu`), this can be disabled since the
+      scaling can be done by the next layer.
+    beta_initializer: Initializer for the beta weight.
+    gamma_initializer: Initializer for the gamma weight.
+    moving_mean_initializer: Initializer for the moving mean.
+    moving_variance_initializer: Initializer for the moving variance.
+    beta_regularizer: Optional regularizer for the beta weight.
+    gamma_regularizer: Optional regularizer for the gamma weight.
+    beta_constraint: An optional projection function to be applied to the `beta`
+      weight after being updated by an `Optimizer` (e.g. used to implement norm
+      constraints or value constraints for layer weights). The function must
+      take as input the unprojected variable and must return the projected
+      variable (which must have the same shape). Constraints are not safe to use
+      when doing asynchronous distributed training.
+    gamma_constraint: An optional projection function to be applied to the
+      `gamma` weight after being updated by an `Optimizer`.
+    renorm: Whether to use Batch Renormalization (Ioffe, 2017). This adds extra
+      variables during training. The inference is the same for either value of
+      this parameter.
+    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
+      scalar `Tensors` used to clip the renorm correction.
The correction `(r, + d)` is used as `corrected_value = normalized_value * r + d`, with `r` + clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + dmax are set to inf, 0, inf, respectively. + renorm_momentum: Momentum used to update the moving means and standard + deviations with renorm. Unlike `momentum`, this affects training and + should be neither too small (which would add noise) nor too large (which + would give stale estimates). Note that `momentum` is still applied to get + the means and variances for inference. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, + which means batch normalization is performed across the whole batch. When + `virtual_batch_size` is not `None`, instead perform "Ghost Batch + Normalization", which creates virtual sub-batches which are each + normalized separately (with shared gamma, beta, and moving statistics). + Must divide the actual batch size during execution. + adjustment: A function taking the `Tensor` containing the (dynamic) shape of + the input tensor and returning a pair (scale, bias) to apply to the + normalized values (before gamma and beta), only during training. For + example, if axis==-1, + `adjustment = lambda shape: ( + tf.random.uniform(shape[-1:], 0.93, 1.07), + tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized + value by up to 7% up or down, then shift the result by up to 0.1 + (with independent scaling and bias for each feature but shared + across all examples), and finally apply gamma and/or beta. If + `None`, no adjustment is applied. Cannot be specified if + virtual_batch_size is specified. + name: A string, the name of the layer. 
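+
+  Example (an illustrative usage sketch; the input shape is assumed):
+
+  ```python
+  x = tf.compat.v1.placeholder(shape=(None, 32), dtype='float32')
+  bn = BatchNormalization(axis=-1)
+  # In graph mode, `training=True` normalizes with batch statistics and adds
+  # the moving-average update ops to the `tf.GraphKeys.UPDATE_OPS` collection;
+  # `training=False` normalizes with the moving statistics.
+  y = bn(x, training=True)
+  ```
+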
+  References:
+    Batch Normalization - Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift:
+      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
+      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
+    Batch Renormalization - Towards Reducing Minibatch Dependence in
+    Batch-Normalized Models:
+      [Ioffe,
+      2017](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models)
+      ([pdf](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models.pdf))
+  """
+
+  def __init__(self,
+               axis=-1,
+               momentum=0.99,
+               epsilon=1e-3,
+               center=True,
+               scale=True,
+               beta_initializer=init_ops.zeros_initializer(),
+               gamma_initializer=init_ops.ones_initializer(),
+               moving_mean_initializer=init_ops.zeros_initializer(),
+               moving_variance_initializer=init_ops.ones_initializer(),
+               beta_regularizer=None,
+               gamma_regularizer=None,
+               beta_constraint=None,
+               gamma_constraint=None,
+               renorm=False,
+               renorm_clipping=None,
+               renorm_momentum=0.99,
+               fused=None,
+               trainable=True,
+               virtual_batch_size=None,
+               adjustment=None,
+               name=None,
+               **kwargs):
+    super(BatchNormalization, self).__init__(
+        axis=axis,
+        momentum=momentum,
+        epsilon=epsilon,
+        center=center,
+        scale=scale,
+        beta_initializer=beta_initializer,
+        gamma_initializer=gamma_initializer,
+        moving_mean_initializer=moving_mean_initializer,
+        moving_variance_initializer=moving_variance_initializer,
+        beta_regularizer=beta_regularizer,
+        gamma_regularizer=gamma_regularizer,
+        beta_constraint=beta_constraint,
+        gamma_constraint=gamma_constraint,
+        renorm=renorm,
+        renorm_clipping=renorm_clipping,
+        renorm_momentum=renorm_momentum,
+        fused=fused,
+        trainable=trainable,
+        virtual_batch_size=virtual_batch_size,
+        adjustment=adjustment,
+        name=name,
+        **kwargs)
+
+  def call(self, inputs, training=False):
+    return super(BatchNormalization, self).call(inputs, training=training)
+
+
+@deprecation.deprecated(
+    date=None, instructions='Use keras.layers.BatchNormalization instead. In '
+    'particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not '
+    'be used (consult the `tf.keras.layers.BatchNormalization` '
+    'documentation).')
+@tf_export(v1=['layers.batch_normalization'])
+def batch_normalization(inputs,
+                        axis=-1,
+                        momentum=0.99,
+                        epsilon=1e-3,
+                        center=True,
+                        scale=True,
+                        beta_initializer=init_ops.zeros_initializer(),
+                        gamma_initializer=init_ops.ones_initializer(),
+                        moving_mean_initializer=init_ops.zeros_initializer(),
+                        moving_variance_initializer=init_ops.ones_initializer(),
+                        beta_regularizer=None,
+                        gamma_regularizer=None,
+                        beta_constraint=None,
+                        gamma_constraint=None,
+                        training=False,
+                        trainable=True,
+                        name=None,
+                        reuse=None,
+                        renorm=False,
+                        renorm_clipping=None,
+                        renorm_momentum=0.99,
+                        fused=None,
+                        virtual_batch_size=None,
+                        adjustment=None):
+  """Functional interface for the batch normalization layer from
+  (Ioffe et al., 2015).
+
+  Note: when training, the moving_mean and moving_variance need to be updated.
+  By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
+  need to be executed alongside the `train_op`. Also, be sure to add any
+  batch_normalization ops before getting the update_ops collection. Otherwise,
+  update_ops will be empty, and training/inference will not work properly. For
+  example:
+
+  ```python
+  x_norm = tf.compat.v1.layers.batch_normalization(x, training=training)
+
+  # ...
+ + update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) + train_op = optimizer.minimize(loss) + train_op = tf.group([train_op, update_ops]) + ``` + + Arguments: + inputs: Tensor input. + axis: An `int`, the axis that should be normalized (typically the features + axis). For instance, after a `Convolution2D` layer with + `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. + momentum: Momentum for the moving average. + epsilon: Small float added to variance to avoid dividing by zero. + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the + next layer is linear (as is also the case for e.g. `nn.relu`), this can be + disabled, since the scaling can be done by the next layer. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + moving_mean_initializer: Initializer for the moving mean. + moving_variance_initializer: Initializer for the moving variance. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + beta_constraint: An optional projection function to be applied to the `beta` + weight after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected variable and must return the projected + variable (which must have the same shape). Constraints are not safe to use + when doing asynchronous distributed training. + gamma_constraint: An optional projection function to be applied to the + `gamma` weight after being updated by an `Optimizer`. + training: Either a Python boolean, or a TensorFlow boolean scalar tensor + (e.g. a placeholder). Whether to return the output in training mode + (normalized with statistics of the current batch) or in inference mode + (normalized with moving statistics). **NOTE**: make sure to set this + parameter correctly, or else your training/inference will not work + properly. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + name: String, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer by the same + name. + renorm: Whether to use Batch Renormalization (Ioffe, 2017). This adds extra + variables during training. The inference is the same for either value of + this parameter. + renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to + scalar `Tensors` used to clip the renorm correction. The correction `(r, + d)` is used as `corrected_value = normalized_value * r + d`, with `r` + clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + dmax are set to inf, 0, inf, respectively. + renorm_momentum: Momentum used to update the moving means and standard + deviations with renorm. Unlike `momentum`, this affects training and + should be neither too small (which would add noise) nor too large (which + would give stale estimates). Note that `momentum` is still applied to get + the means and variances for inference. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. + virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, + which means batch normalization is performed across the whole batch.
When + `virtual_batch_size` is not `None`, instead perform "Ghost Batch + Normalization", which creates virtual sub-batches which are each + normalized separately (with shared gamma, beta, and moving statistics). + Must divide the actual batch size during execution. + adjustment: A function taking the `Tensor` containing the (dynamic) shape of + the input tensor and returning a pair (scale, bias) to apply to the + normalized values (before gamma and beta), only during training. For + example, if axis==-1, + `adjustment = lambda shape: ( + tf.random.uniform(shape[-1:], 0.93, 1.07), + tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized + value by up to 7% up or down, then shift the result by up to 0.1 + (with independent scaling and bias for each feature but shared + across all examples), and finally apply gamma and/or beta. If + `None`, no adjustment is applied. Cannot be specified if + virtual_batch_size is specified. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + References: + Batch Normalization - Accelerating Deep Network Training by Reducing + Internal Covariate Shift: + [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html) + ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf)) + Batch Renormalization - Towards Reducing Minibatch Dependence in + Batch-Normalized Models: + [Ioffe, + 2017](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models) + ([pdf](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models.pdf)) + """ + layer = BatchNormalization( + axis=axis, + momentum=momentum, + epsilon=epsilon, + center=center, + scale=scale, + beta_initializer=beta_initializer, + gamma_initializer=gamma_initializer, + moving_mean_initializer=moving_mean_initializer, + moving_variance_initializer=moving_variance_initializer, + beta_regularizer=beta_regularizer, + gamma_regularizer=gamma_regularizer, + beta_constraint=beta_constraint, + gamma_constraint=gamma_constraint, + renorm=renorm, + renorm_clipping=renorm_clipping, + renorm_momentum=renorm_momentum, + fused=fused, + trainable=trainable, + virtual_batch_size=virtual_batch_size, + adjustment=adjustment, + name=name, + _reuse=reuse, + _scope=name) + return layer.apply(inputs, training=training) + + +# Aliases + +BatchNorm = BatchNormalization +batch_norm = batch_normalization diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/keras/legacy_tf_layers/normalization_test.py similarity index 99% rename from tensorflow/python/layers/normalization_test.py rename to tensorflow/python/keras/legacy_tf_layers/normalization_test.py index 7672a9da84a..668fab885cc 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/keras/legacy_tf_layers/normalization_test.py @@ -26,8 +26,8 @@ from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.layers import convolutional as conv_layers -from tensorflow.python.layers import normalization as normalization_layers +from tensorflow.python.keras.legacy_tf_layers import convolutional as conv_layers +from tensorflow.python.keras.legacy_tf_layers import normalization as normalization_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import 
math_ops diff --git a/tensorflow/python/keras/legacy_tf_layers/pooling.py b/tensorflow/python/keras/legacy_tf_layers/pooling.py new file mode 100644 index 00000000000..2e1ba36c5b9 --- /dev/null +++ b/tensorflow/python/keras/legacy_tf_layers/pooling.py @@ -0,0 +1,470 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Contains the pooling layer classes and their functional aliases. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras import layers as keras_layers +from tensorflow.python.keras.legacy_tf_layers import base +from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export + + +@tf_export(v1=['layers.AveragePooling1D']) +class AveragePooling1D(keras_layers.AveragePooling1D, base.Layer): + """Average Pooling layer for 1D inputs. + + Arguments: + pool_size: An integer or tuple/list of a single integer, + representing the size of the pooling window. + strides: An integer or tuple/list of a single integer, specifying the + strides of the pooling operation. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + name: A string, the name of the layer. + """ + + def __init__(self, pool_size, strides, + padding='valid', data_format='channels_last', + name=None, **kwargs): + if strides is None: + raise ValueError('Argument `strides` must not be None.') + super(AveragePooling1D, self).__init__( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + name=name, + **kwargs) + + +@deprecation.deprecated( + date=None, instructions='Use keras.layers.AveragePooling1D instead.') +@tf_export(v1=['layers.average_pooling1d']) +def average_pooling1d(inputs, pool_size, strides, + padding='valid', data_format='channels_last', + name=None): + """Average Pooling layer for 1D inputs. + + Arguments: + inputs: The tensor over which to pool. Must have rank 3. + pool_size: An integer or tuple/list of a single integer, + representing the size of the pooling window. + strides: An integer or tuple/list of a single integer, specifying the + strides of the pooling operation. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. 
+ name: A string, the name of the layer. + + Returns: + The output tensor, of rank 3. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = AveragePooling1D(pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + name=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.MaxPooling1D']) +class MaxPooling1D(keras_layers.MaxPooling1D, base.Layer): + """Max Pooling layer for 1D inputs. + + Arguments: + pool_size: An integer or tuple/list of a single integer, + representing the size of the pooling window. + strides: An integer or tuple/list of a single integer, specifying the + strides of the pooling operation. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + name: A string, the name of the layer. + """ + + def __init__(self, pool_size, strides, + padding='valid', data_format='channels_last', + name=None, **kwargs): + if strides is None: + raise ValueError('Argument `strides` must not be None.') + super(MaxPooling1D, self).__init__( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + name=name, + **kwargs) + + +@deprecation.deprecated( + date=None, instructions='Use keras.layers.MaxPooling1D instead.') +@tf_export(v1=['layers.max_pooling1d']) +def max_pooling1d(inputs, pool_size, strides, + padding='valid', data_format='channels_last', + name=None): + """Max Pooling layer for 1D inputs. + + Arguments: + inputs: The tensor over which to pool. Must have rank 3. + pool_size: An integer or tuple/list of a single integer, + representing the size of the pooling window. + strides: An integer or tuple/list of a single integer, specifying the + strides of the pooling operation. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + name: A string, the name of the layer. + + Returns: + The output tensor, of rank 3. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = MaxPooling1D(pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + name=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.AveragePooling2D']) +class AveragePooling2D(keras_layers.AveragePooling2D, base.Layer): + """Average pooling layer for 2D inputs (e.g. images). + + Arguments: + pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) + specifying the size of the pooling window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the pooling operation. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs. + `channels_last` (default) and `channels_first` are supported. 
+ `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + name: A string, the name of the layer. + """ + + def __init__(self, pool_size, strides, + padding='valid', data_format='channels_last', + name=None, **kwargs): + if strides is None: + raise ValueError('Argument `strides` must not be None.') + super(AveragePooling2D, self).__init__( + pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, name=name, **kwargs) + + +@deprecation.deprecated( + date=None, instructions='Use keras.layers.AveragePooling2D instead.') +@tf_export(v1=['layers.average_pooling2d']) +def average_pooling2d(inputs, + pool_size, strides, + padding='valid', data_format='channels_last', + name=None): + """Average pooling layer for 2D inputs (e.g. images). + + Arguments: + inputs: The tensor over which to pool. Must have rank 4. + pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) + specifying the size of the pooling window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the pooling operation. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs. + `channels_last` (default) and `channels_first` are supported. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + name: A string, the name of the layer. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = AveragePooling2D(pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, + name=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.MaxPooling2D']) +class MaxPooling2D(keras_layers.MaxPooling2D, base.Layer): + """Max pooling layer for 2D inputs (e.g. images). + + Arguments: + pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) + specifying the size of the pooling window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the pooling operation. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs. + `channels_last` (default) and `channels_first` are supported. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + name: A string, the name of the layer. 
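As a quick illustration of the 2D pooling arguments listed above, a minimal sketch; the input shape and variable names are assumptions for the example, not part of this patch:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # tf.compat.v1.layers requires graph mode

images = tf.placeholder(tf.float32, shape=(None, 28, 28, 3))

# Class form: a 2x2 window with stride 2 and 'valid' padding halves
# each spatial dimension: (None, 28, 28, 3) -> (None, 14, 14, 3).
pool = tf.layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
pooled = pool(images)

# Equivalent functional form, deprecated in favor of
# tf.keras.layers.MaxPooling2D.
pooled_fn = tf.layers.max_pooling2d(images, pool_size=2, strides=2)
```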
+ """ + + def __init__(self, pool_size, strides, + padding='valid', data_format='channels_last', + name=None, **kwargs): + if strides is None: + raise ValueError('Argument `strides` must not be None.') + super(MaxPooling2D, self).__init__( + pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, name=name, **kwargs) + + +@deprecation.deprecated( + date=None, instructions='Use keras.layers.MaxPooling2D instead.') +@tf_export(v1=['layers.max_pooling2d']) +def max_pooling2d(inputs, + pool_size, strides, + padding='valid', data_format='channels_last', + name=None): + """Max pooling layer for 2D inputs (e.g. images). + + Arguments: + inputs: The tensor over which to pool. Must have rank 4. + pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) + specifying the size of the pooling window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the pooling operation. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs. + `channels_last` (default) and `channels_first` are supported. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + name: A string, the name of the layer. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = MaxPooling2D(pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, + name=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.AveragePooling3D']) +class AveragePooling3D(keras_layers.AveragePooling3D, base.Layer): + """Average pooling layer for 3D inputs (e.g. volumes). + + Arguments: + pool_size: An integer or tuple/list of 3 integers: + (pool_depth, pool_height, pool_width) + specifying the size of the pooling window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the pooling operation. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs. + `channels_last` (default) and `channels_first` are supported. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + name: A string, the name of the layer. + """ + + def __init__(self, pool_size, strides, + padding='valid', data_format='channels_last', + name=None, **kwargs): + if strides is None: + raise ValueError('Argument `strides` must not be None.') + super(AveragePooling3D, self).__init__( + pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, name=name, **kwargs) + + +@deprecation.deprecated( + date=None, instructions='Use keras.layers.AveragePooling3D instead.') +@tf_export(v1=['layers.average_pooling3d']) +def average_pooling3d(inputs, + pool_size, strides, + padding='valid', data_format='channels_last', + name=None): + """Average pooling layer for 3D inputs (e.g. volumes). 
+ + Arguments: + inputs: The tensor over which to pool. Must have rank 5. + pool_size: An integer or tuple/list of 3 integers: + (pool_depth, pool_height, pool_width) + specifying the size of the pooling window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the pooling operation. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs. + `channels_last` (default) and `channels_first` are supported. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + name: A string, the name of the layer. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = AveragePooling3D(pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, + name=name) + return layer.apply(inputs) + + +@tf_export(v1=['layers.MaxPooling3D']) +class MaxPooling3D(keras_layers.MaxPooling3D, base.Layer): + """Max pooling layer for 3D inputs (e.g. volumes). + + Arguments: + pool_size: An integer or tuple/list of 3 integers: + (pool_depth, pool_height, pool_width) + specifying the size of the pooling window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the pooling operation. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs. + `channels_last` (default) and `channels_first` are supported. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + name: A string, the name of the layer. + """ + + def __init__(self, pool_size, strides, + padding='valid', data_format='channels_last', + name=None, **kwargs): + if strides is None: + raise ValueError('Argument `strides` must not be None.') + super(MaxPooling3D, self).__init__( + pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, name=name, **kwargs) + + +@deprecation.deprecated( + date=None, instructions='Use keras.layers.MaxPooling3D instead.') +@tf_export(v1=['layers.max_pooling3d']) +def max_pooling3d(inputs, + pool_size, strides, + padding='valid', data_format='channels_last', + name=None): + """Max pooling layer for 3D inputs (e.g. volumes). + + Arguments: + inputs: The tensor over which to pool. Must have rank 5. + pool_size: An integer or tuple/list of 3 integers: (pool_depth, pool_height, + pool_width) specifying the size of the pooling window. Can be a single + integer to specify the same value for all spatial dimensions. + strides: An integer or tuple/list of 3 integers, specifying the strides of + the pooling operation. Can be a single integer to specify the same value + for all spatial dimensions. + padding: A string. The padding method, either 'valid' or 'same'. + Case-insensitive. + data_format: A string. The ordering of the dimensions in the inputs.
+ `channels_last` (default) and `channels_first` are supported. + `channels_last` corresponds to inputs with shape `(batch, depth, height, + width, channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + name: A string, the name of the layer. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + """ + layer = MaxPooling3D(pool_size=pool_size, strides=strides, + padding=padding, data_format=data_format, + name=name) + return layer.apply(inputs) + +# Aliases + +AvgPool2D = AveragePooling2D +MaxPool2D = MaxPooling2D +max_pool2d = max_pooling2d +avg_pool2d = average_pooling2d diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/keras/legacy_tf_layers/pooling_test.py similarity index 99% rename from tensorflow/python/layers/pooling_test.py rename to tensorflow/python/keras/legacy_tf_layers/pooling_test.py index cf1fa1e6915..0fd63ed335f 100644 --- a/tensorflow/python/layers/pooling_test.py +++ b/tensorflow/python/keras/legacy_tf_layers/pooling_test.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import test_util -from tensorflow.python.layers import pooling as pooling_layers +from tensorflow.python.keras.legacy_tf_layers import pooling as pooling_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 5a944703af9..1021e45cbd4 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -17,577 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy +from tensorflow.python.keras.legacy_tf_layers import base -from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.keras import backend -from tensorflow.python.keras.engine import base_layer -from tensorflow.python.keras.mixed_precision.experimental import policy -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.training.tracking import base as trackable -from tensorflow.python.util import deprecation -from tensorflow.python.util import function_utils -from tensorflow.python.util import nest -from tensorflow.python.util import tf_contextlib -from tensorflow.python.util.tf_export import tf_export +InputSpec = base.InputSpec -# Avoid breaking users who directly import this symbol from this file. -# TODO(fchollet): remove this. -InputSpec = base_layer.InputSpec # pylint: disable=invalid-name - -_KERAS_STYLE_SCOPE = False - - -@tf_export(v1=['layers.experimental.keras_style_scope']) -@tf_contextlib.contextmanager -def keras_style_scope(): - """Use Keras-style variable management. - - All tf.layers and tf RNN cells created in this scope use Keras-style - variable management. Creating such layers with a scope= argument is - disallowed, and reuse=True is disallowed. - - The purpose of this scope is to allow users of existing layers to - slowly transition to a Keras layers API without breaking existing - functionality. - - One example of this is when using TensorFlow's RNN classes with Keras - Models or Networks. 
Because Keras models do not properly set variable - scopes, users of RNNs may either accidentally share scopes between two - different models, or get errors about variables that already exist. - - Example: - - ```python - class RNNModel(tf.keras.Model): - - def __init__(self, name): - super(RNNModel, self).__init__(name=name) - self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell( - [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)]) - - def call(self, input, state): - return self.rnn(input, state) - - model_1 = RNNModel("model_1") - model_2 = RNNModel("model_2") - - # OK - output_1, next_state_1 = model_1(input, state) - # Raises an error about trying to create an already existing variable. - output_2, next_state_2 = model_2(input, state) - ``` - - The solution is to wrap the model construction and execution in a keras-style - scope: - - ```python - with keras_style_scope(): - model_1 = RNNModel("model_1") - model_2 = RNNModel("model_2") - - # model_1 and model_2 are guaranteed to create their own variables. - output_1, next_state_1 = model_1(input, state) - output_2, next_state_2 = model_2(input, state) - - assert len(model_1.weights) > 0 - assert len(model_2.weights) > 0 - assert(model_1.weights != model_2.weights) - ``` - - Yields: - A keras layer style scope. - """ - global _KERAS_STYLE_SCOPE - stack = _KERAS_STYLE_SCOPE - _KERAS_STYLE_SCOPE = True - try: - yield - finally: - _KERAS_STYLE_SCOPE = stack - - -@tf_export(v1=['layers.experimental.set_keras_style']) -def set_keras_style(): - """Use Keras-style variable management. - - All tf.layers and tf RNN cells created after keras style ha been enabled - use Keras-style variable management. Creating such layers with a - scope= argument is disallowed, and reuse=True is disallowed. - - The purpose of this function is to allow users of existing layers to - slowly transition to Keras layers API without breaking existing - functionality. - - For more details, see the documentation for `keras_style_scope`. - - Note, once keras style has been set, it is set globally for the entire - program and cannot be unset. - - Example: - - ```python - set_keras_style() - - model_1 = RNNModel(name="model_1") - model_2 = RNNModel(name="model_2") - - # model_1 and model_2 are guaranteed to create their own variables. - output_1, next_state_1 = model_1(input, state) - output_2, next_state_2 = model_2(input, state) - - assert len(model_1.weights) > 0 - assert len(model_2.weights) > 0 - assert(model_1.weights != model_2.weights) - ``` - """ - global _KERAS_STYLE_SCOPE - _KERAS_STYLE_SCOPE = True - - -def _is_in_keras_style_scope(): - global _KERAS_STYLE_SCOPE - return _KERAS_STYLE_SCOPE - - -@tf_export(v1=['layers.Layer']) -class Layer(base_layer.Layer): - """Base layer class. - - It is considered legacy, and we recommend the use of `tf.keras.layers.Layer` - instead. - - Arguments: - trainable: Boolean, whether the layer's variables should be trainable. - name: String name of the layer. - dtype: Default dtype of the layer's weights (default of `None` means use the - type of the first input). - - Read-only properties: - name: The name of the layer (string). - dtype: Default dtype of the layer's weights (default of `None` means use the - type of the first input). - trainable_variables: List of trainable variables. - non_trainable_variables: List of non-trainable variables. - variables: List of all variables of this layer, trainable and - non-trainable. - updates: List of update ops of this layer. - losses: List of losses added by this layer. 
- trainable_weights: List of variables to be included in backprop. - non_trainable_weights: List of variables that should not be - included in backprop. - weights: The concatenation of the lists trainable_weights and - non_trainable_weights (in this order). - - Mutable properties: - trainable: Whether the layer should be trained (boolean). - input_spec: Optional (list of) `InputSpec` object(s) specifying the - constraints on inputs that can be accepted by the layer. - """ - - def __init__(self, trainable=True, name=None, dtype=None, - **kwargs): - # For backwards compatibility, legacy layers do not use `ResourceVariable` - # by default. - self._use_resource_variables = False - scope = kwargs.pop('_scope', None) - self._reuse = kwargs.pop('_reuse', None) - - # Avoid an incorrect lint error - self._trainable_weights = [] - self.built = False - - if dtype is None: - # Indicates to infer dtype from inputs. When the V2 dtype behavior is - # enabled, Keras layers default their dtype to floatx instead, so we pass - # an "_infer" policy to keep the old V1 behavior. - dtype = policy.Policy('_infer') - - if 'autocast' not in kwargs: - kwargs['autocast'] = False - - super(Layer, self).__init__(trainable=trainable, name=name, dtype=dtype, - **kwargs) - - if _is_in_keras_style_scope(): - if scope is not None: - raise ValueError( - 'scope argument not allowed when keras style layers are enabled, ' - 'but saw: {}'.format(scope)) - if self._reuse is not None: - raise ValueError( - 'reuse argument not allowed when keras style layers are enabled, ' - 'but saw: {}'.format(self._reuse)) - self._keras_style = True - else: - self._keras_style = False - - self._call_has_scope_arg = 'scope' in self._call_fn_args - if scope: - with vs.variable_scope(scope) as captured_scope: - self._scope = captured_scope - else: - self._scope = None - self._current_scope = None - - # We no longer track graph in tf.layers layers. This property is only kept to - # maintain API backward compatibility. - @property - @deprecation.deprecated( - date=None, - instructions='Stop using this property because tf.layers layers no ' - 'longer track their graph.') - def graph(self): - if context.executing_eagerly(): - raise RuntimeError('Layer.graph not supported when executing eagerly.') - return None - - def _init_set_name(self, name): - # Determine layer name (non-unique). - if isinstance(name, vs.VariableScope): - base_name = name.name - self._name, _ = self._make_unique_name() - else: - base_name = name - self._name = name - if not name: - self._name, base_name = self._make_unique_name() - self._base_name = base_name - - def _make_unique_name(self, name_uid_map=None, avoid_names=None, - namespace='', zero_based=False): - base_name = base_layer.to_snake_case(self.__class__.__name__) - name = backend.unique_object_name( - base_name, - name_uid_map=name_uid_map, - avoid_names=avoid_names, - namespace=namespace, - zero_based=zero_based) - return (name, base_name) - - @property - def scope_name(self): - if not self._scope: - raise ValueError('No name available for layer scope because the layer "' + - self._name + '" has not been used yet. The scope name ' + - ' is determined the first time the layer instance is ' + - 'called. 
You must therefore call the layer before ' + - 'querying `scope_name`.') - return self._scope.name - - def add_loss(self, losses, inputs=None): - previous_losses_length = len(self._losses) - previous_callable_losses_length = len(self._callable_losses) - super(Layer, self).add_loss(losses, inputs=inputs) - if not context.executing_eagerly(): - # TODO(fchollet): deprecate collection below. - new_losses = self._losses[previous_losses_length:] - new_callable_losses = self._callable_losses[ - previous_callable_losses_length:] - for regularizer in new_callable_losses: - loss_tensor = regularizer() - if loss_tensor is not None: - new_losses.append(loss_tensor) - _add_elements_to_collection( - new_losses, - ops.GraphKeys.REGULARIZATION_LOSSES) - - def _name_scope(self): - """Determines op naming for the Layer.""" - if self._keras_style: - return super(Layer, self)._name_scope() - return self._current_scope.original_name_scope - - def _set_scope(self, scope=None): - if self._scope is None: - # If constructed with _scope=None, lazy setting of scope. - if self._reuse: - with vs.variable_scope( - scope if scope is not None else self._base_name) as captured_scope: - self._scope = captured_scope - else: - with vs.variable_scope( - scope, default_name=self._base_name) as captured_scope: - self._scope = captured_scope - - def add_weight(self, - name, - shape, - dtype=None, - initializer=None, - regularizer=None, - trainable=None, - constraint=None, - use_resource=None, - synchronization=vs.VariableSynchronization.AUTO, - aggregation=vs.VariableAggregation.NONE, - partitioner=None, - **kwargs): - """Adds a new variable to the layer, or gets an existing one; returns it. - - Arguments: - name: variable name. - shape: variable shape. - dtype: The type of the variable. Defaults to `self.dtype` or `float32`. - initializer: initializer instance (callable). - regularizer: regularizer instance (callable). - trainable: whether the variable should be part of the layer's - "trainable_variables" (e.g. variables, biases) - or "non_trainable_variables" (e.g. BatchNorm mean, stddev). - Note, if the current variable scope is marked as non-trainable - then this parameter is ignored and any added variables are also - marked as non-trainable. `trainable` defaults to `True` unless - `synchronization` is set to `ON_READ`. - constraint: constraint instance (callable). - use_resource: Whether to use `ResourceVariable`. - synchronization: Indicates when a distributed a variable will be - aggregated. Accepted values are constants defined in the class - `tf.VariableSynchronization`. By default the synchronization is set to - `AUTO` and the current `DistributionStrategy` chooses - when to synchronize. If `synchronization` is set to `ON_READ`, - `trainable` must not be set to `True`. - aggregation: Indicates how a distributed variable will be aggregated. - Accepted values are constants defined in the class - `tf.VariableAggregation`. - partitioner: (optional) partitioner instance (callable). If - provided, when the requested variable is created it will be split - into multiple partitions according to `partitioner`. In this case, - an instance of `PartitionedVariable` is returned. Available - partitioners include `tf.compat.v1.fixed_size_partitioner` and - `tf.compat.v1.variable_axis_size_partitioner`. For more details, see - the documentation of `tf.compat.v1.get_variable` and the "Variable - Partitioners and Sharding" section of the API guide. - **kwargs: Additional keyword arguments. - - Returns: - The created variable. 
Usually either a `Variable` or `ResourceVariable` - instance. If `partitioner` is not `None`, a `PartitionedVariable` - instance is returned. - - Raises: - RuntimeError: If called with partitioned variable regularization and - eager execution is enabled. - ValueError: When trainable has been set to True with synchronization - set as `ON_READ`. - """ - for kwarg in kwargs: - if kwarg != 'experimental_autocast': - raise TypeError('Unknown keyword argument:', kwarg) - if self._keras_style: - return super(Layer, self).add_weight( - name=name, - shape=shape, - dtype=dtype, - initializer=initializer, - regularizer=regularizer, - trainable=trainable and self.trainable, - constraint=constraint, - use_resource=use_resource, - synchronization=vs.VariableSynchronization.AUTO, - aggregation=vs.VariableAggregation.NONE, - partitioner=partitioner, - **kwargs) - - if synchronization == vs.VariableSynchronization.ON_READ: - if trainable: - raise ValueError( - 'Synchronization value can be set to ' - 'VariableSynchronization.ON_READ only for non-trainable variables. ' - 'You have specified trainable=True and ' - 'synchronization=VariableSynchronization.ON_READ.') - else: - # Set trainable to be false when variable is to be synced on read. - trainable = False - elif trainable is None: - trainable = True - - def _should_add_regularizer(variable, existing_variable_set): - if isinstance(variable, tf_variables.PartitionedVariable): - for var in variable: - if var in existing_variable_set: - return False - return True - else: - return variable not in existing_variable_set - - init_graph = None - if not context.executing_eagerly(): - default_graph = ops.get_default_graph() - if default_graph.building_function: - with ops.init_scope(): - # Retrieve the variables from the graph into which variables - # will be lifted; if initialization ops will be lifted into - # the eager context, then there is nothing to retrieve, since variable - # collections are not supported when eager execution is enabled. - if not context.executing_eagerly(): - init_graph = ops.get_default_graph() - existing_variables = set(tf_variables.global_variables()) - else: - # Initialization ops will not be lifted out of the default graph. - init_graph = default_graph - existing_variables = set(tf_variables.global_variables()) - - if dtype is None: - dtype = self.dtype or dtypes.float32 - - self._set_scope(None) - reuse = self.built or self._reuse - prev_len_trainable = len(self._trainable_weights) - with vs.variable_scope( - self._scope, reuse=reuse, auxiliary_name_scope=False) as scope: - self._current_scope = scope - with ops.name_scope(self._name_scope(), skip_on_eager=False): - use_resource = (use_resource or - self._use_resource_variables or - scope.use_resource) - if initializer is None: - initializer = scope.initializer - variable = super(Layer, self).add_weight( - name, - shape, - dtype=dtypes.as_dtype(dtype), - initializer=initializer, - trainable=trainable and self.trainable, - constraint=constraint, - partitioner=partitioner, - use_resource=use_resource, - synchronization=synchronization, - aggregation=aggregation, - getter=vs.get_variable, - **kwargs) - - if regularizer: - if (ops.executing_eagerly_outside_functions() - or _should_add_regularizer(variable, existing_variables)): - self._handle_weight_regularization(name, variable, regularizer) - - if init_graph is not None: - # Handle edge case where a custom getter has overridden `trainable`. 
- # There is one known occurrence of this, in unit test - # testBasicRNNCellNotTrainable in - # contrib.rnn.python.kernel_tests.core_rnn_cell_test - with init_graph.as_default(): - trainable_variables = tf_variables.trainable_variables() - if (trainable and self.trainable and - variable not in trainable_variables): - # A custom getter / variable scope overrode the trainable flag. - extra_trainable_vars = self._trainable_weights[prev_len_trainable:] - self._trainable_weights = self._trainable_weights[ - :prev_len_trainable] - self._non_trainable_weights += extra_trainable_vars - return variable - - def __call__(self, inputs, *args, **kwargs): - """Wraps `call`, applying pre- and post-processing steps. - - Arguments: - inputs: input tensor(s). - *args: additional positional arguments to be passed to `self.call`. - **kwargs: additional keyword arguments to be passed to `self.call`. - **Note**: kwarg `scope` is reserved for use by the layer. - - Returns: - Output tensor(s). - - Note: - - If the layer's `call` method takes a `scope` keyword argument, - this argument will be automatically set to the current variable scope. - - If the layer's `call` method takes a `mask` argument (as some Keras - layers do), its default value will be set to the mask generated - for `inputs` by the previous layer (if `input` did come from - a layer that generated a corresponding mask, i.e. if it came from - a Keras layer with masking support. - - Raises: - ValueError: if the layer's `call` method returns None (an invalid value). - """ - scope = kwargs.pop('scope', None) - - if self._keras_style: - if scope is not None: - raise ValueError( - 'scope argument not allowed when keras style layers are enabled, ' - 'but saw: {}'.format(scope)) - return super(Layer, self).__call__(inputs, *args, **kwargs) - - self._set_scope(scope) - - if self.built: - try: - # Some classes which inherit from Layer do not use its constructor, so - # rather than initializing to None we check for an AttributeError. - scope_context_manager = self._always_reuse_variable_scope - except AttributeError: - # From this point we will always set reuse=True, so create a "final" - # variable scope with this setting. We avoid re-creating variable scopes - # after this point as an optimization. - self._always_reuse_variable_scope = vs.variable_scope( - self._scope, reuse=True, auxiliary_name_scope=False) - scope_context_manager = self._always_reuse_variable_scope - else: - scope_context_manager = vs.variable_scope( - self._scope, reuse=self._reuse, auxiliary_name_scope=False) - - with scope_context_manager as scope: - self._current_scope = scope - - try: - call_has_scope_arg = self._call_has_scope_arg - except AttributeError: - self._call_fn_args = function_utils.fn_args(self.call) - self._call_has_scope_arg = 'scope' in self._call_fn_args - call_has_scope_arg = self._call_has_scope_arg - if call_has_scope_arg: - kwargs['scope'] = scope - - # Actually call layer - outputs = super(Layer, self).__call__(inputs, *args, **kwargs) - - if not context.executing_eagerly(): - # Update global default collections. 
- _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) - return outputs - - def __deepcopy__(self, memo): - no_copy = set(['_graph', '_thread_local', '_metrics_lock']) - shallow_copy = set(['_scope', '_always_reuse_variable_scope']) - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k in no_copy: - setattr(result, k, v) - elif k in shallow_copy: - setattr(result, k, copy.copy(v)) - elif base_layer.is_tensor_or_tensor_list(v): - setattr(result, k, v) - else: - setattr(result, k, copy.deepcopy(v, memo)) - return result - - def __setattr__(self, value, name): - # By-pass the automatic dependency tracking performed by the parent Layer. - super(trackable.Trackable, self).__setattr__(value, name) - - @property - def _is_legacy_layer(self): - """Used by keras to check compatibility. This should not be overridden.""" - return True - - -def _add_elements_to_collection(elements, collection_list): - if context.executing_eagerly(): - raise RuntimeError('Using collections from Layers not supported in Eager ' - 'mode. Tried to add %s to %s' % (elements, - collection_list)) - elements = nest.flatten(elements) - collection_list = nest.flatten(collection_list) - for name in collection_list: - collection = ops.get_collection_ref(name) - collection_set = {id(e) for e in collection} - for element in elements: - if id(element) not in collection_set: - collection.append(element) +keras_style_scope = base.keras_style_scope +set_keras_style = base.set_keras_style +Layer = base.Layer diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index f88934122fc..f3839facc8b 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -19,1439 +19,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras import layers as keras_layers -from tensorflow.python.layers import base -from tensorflow.python.ops import init_ops -from tensorflow.python.util import deprecation -from tensorflow.python.util.tf_export import tf_export - - -@tf_export(v1=['layers.Conv1D']) -class Conv1D(keras_layers.Conv1D, base.Layer): - """1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. If `use_bias` is True (and a `bias_initializer` is provided), - a bias vector is created and added to the outputs. Finally, if - `activation` is not `None`, it is applied to the outputs as well. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. 
- dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - """ - - def __init__(self, filters, - kernel_size, - strides=1, - padding='valid', - data_format='channels_last', - dilation_rate=1, - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(Conv1D, self).__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, **kwargs) - - -@deprecation.deprecated( - date=None, - instructions='Use `tf.keras.layers.Conv1D` instead.') -@tf_export(v1=['layers.conv1d']) -def conv1d(inputs, - filters, - kernel_size, - strides=1, - padding='valid', - data_format='channels_last', - dilation_rate=1, - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for 1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. If `use_bias` is True (and a `bias_initializer` is provided), - a bias vector is created and added to the outputs. Finally, if - `activation` is not `None`, it is applied to the outputs as well. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). 
- kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = Conv1D( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.Conv2D']) -class Conv2D(keras_layers.Conv2D, base.Layer): - """2D convolution layer (e.g. spatial convolution over images). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. If `use_bias` is True (and a `bias_initializer` is provided), - a bias vector is created and added to the outputs. Finally, if - `activation` is not `None`, it is applied to the outputs as well. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). 
- kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. 
- """ - - def __init__(self, filters, - kernel_size, - strides=(1, 1), - padding='valid', - data_format='channels_last', - dilation_rate=(1, 1), - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(Conv2D, self).__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, **kwargs) - - -@deprecation.deprecated( - date=None, - instructions='Use `tf.keras.layers.Conv2D` instead.') -@tf_export(v1=['layers.conv2d']) -def conv2d(inputs, - filters, - kernel_size, - strides=(1, 1), - padding='valid', - data_format='channels_last', - dilation_rate=(1, 1), - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for the 2D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. If `use_bias` is True (and a `bias_initializer` is provided), - a bias vector is created and added to the outputs. Finally, if - `activation` is not `None`, it is applied to the outputs as well. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. 
If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = Conv2D( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.Conv3D']) -class Conv3D(keras_layers.Conv3D, base.Layer): - """3D convolution layer (e.g. spatial convolution over volumes). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. If `use_bias` is True (and a `bias_initializer` is provided), - a bias vector is created and added to the outputs. Finally, if - `activation` is not `None`, it is applied to the outputs as well. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. 
- activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - """ - - def __init__(self, filters, - kernel_size, - strides=(1, 1, 1), - padding='valid', - data_format='channels_last', - dilation_rate=(1, 1, 1), - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(Conv3D, self).__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, **kwargs) - - -@deprecation.deprecated( - date=None, - instructions='Use `tf.keras.layers.Conv3D` instead.') -@tf_export(v1=['layers.conv3d']) -def conv3d(inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding='valid', - data_format='channels_last', - dilation_rate=(1, 1, 1), - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for the 3D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. If `use_bias` is True (and a `bias_initializer` is provided), - a bias vector is created and added to the outputs. Finally, if - `activation` is not `None`, it is applied to the outputs as well. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. 
- strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = Conv3D( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.SeparableConv1D']) -class SeparableConv1D(keras_layers.SeparableConv1D, base.Layer): - """Depthwise separable 1D convolution. - - This layer performs a depthwise convolution that acts separately on - channels, followed by a pointwise convolution that mixes channels. - If `use_bias` is True and a bias initializer is provided, - it adds a bias vector to the output. - It then optionally applies an activation function to produce the final output. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. 
the number - of filters in the convolution). - kernel_size: A single integer specifying the spatial - dimensions of the filters. - strides: A single integer specifying the strides - of the convolution. - Specifying any `stride` value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: A single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - depth_multiplier: The number of depthwise convolution output channels for - each input channel. The total number of depthwise convolution output - channels will be equal to `num_filters_in * depth_multiplier`. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - depthwise_initializer: An initializer for the depthwise convolution kernel. - pointwise_initializer: An initializer for the pointwise convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - depthwise_regularizer: Optional regularizer for the depthwise - convolution kernel. - pointwise_regularizer: Optional regularizer for the pointwise - convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - depthwise_constraint: Optional projection function to be applied to the - depthwise kernel after being updated by an `Optimizer` (e.g. used for - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - pointwise_constraint: Optional projection function to be applied to the - pointwise kernel after being updated by an `Optimizer`. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. 
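The `depth_multiplier` arithmetic described above is easiest to see with concrete shapes; a sketch under assumed input shapes (not from this patch):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical sequence batch: (batch, length, channels) = (8, 100, 4).
signal = tf.placeholder(tf.float32, shape=(8, 100, 4))

# Depthwise stage: 4 input channels * depth_multiplier=2 -> 8 intermediate
# channels; the pointwise stage then mixes them down to filters=16.
sep = tf.layers.SeparableConv1D(filters=16, kernel_size=5,
                                depth_multiplier=2, padding='same')
out = sep(signal)  # shape: (8, 100, 16)
```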
- """ - - def __init__(self, filters, - kernel_size, - strides=1, - padding='valid', - data_format='channels_last', - dilation_rate=1, - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer=None, - pointwise_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(SeparableConv1D, self).__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - pointwise_initializer=pointwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - pointwise_regularizer=pointwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - pointwise_constraint=pointwise_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - **kwargs) - - -@tf_export(v1=['layers.SeparableConv2D']) -class SeparableConv2D(keras_layers.SeparableConv2D, base.Layer): - """Depthwise separable 2D convolution. - - This layer performs a depthwise convolution that acts separately on - channels, followed by a pointwise convolution that mixes channels. - If `use_bias` is True and a bias initializer is provided, - it adds a bias vector to the output. - It then optionally applies an activation function to produce the final output. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: A tuple or list of 2 integers specifying the spatial - dimensions of the filters. Can be a single integer to specify the same - value for all spatial dimensions. - strides: A tuple or list of 2 positive integers specifying the strides - of the convolution. Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any `stride` value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - depth_multiplier: The number of depthwise convolution output channels for - each input channel. The total number of depthwise convolution output - channels will be equal to `num_filters_in * depth_multiplier`. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - depthwise_initializer: An initializer for the depthwise convolution kernel. - pointwise_initializer: An initializer for the pointwise convolution kernel. 
- bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - depthwise_regularizer: Optional regularizer for the depthwise - convolution kernel. - pointwise_regularizer: Optional regularizer for the pointwise - convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - depthwise_constraint: Optional projection function to be applied to the - depthwise kernel after being updated by an `Optimizer` (e.g. used for - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - pointwise_constraint: Optional projection function to be applied to the - pointwise kernel after being updated by an `Optimizer`. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - """ - - def __init__(self, filters, - kernel_size, - strides=(1, 1), - padding='valid', - data_format='channels_last', - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer=None, - pointwise_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(SeparableConv2D, self).__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - pointwise_initializer=pointwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - pointwise_regularizer=pointwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - pointwise_constraint=pointwise_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - **kwargs) - - -@deprecation.deprecated( - date=None, - instructions='Use `tf.keras.layers.SeparableConv1D` instead.') -@tf_export(v1=['layers.separable_conv1d']) -def separable_conv1d(inputs, - filters, - kernel_size, - strides=1, - padding='valid', - data_format='channels_last', - dilation_rate=1, - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer=None, - pointwise_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for the depthwise separable 1D convolution layer. - - This layer performs a depthwise convolution that acts separately on - channels, followed by a pointwise convolution that mixes channels. 
- If `use_bias` is True and a bias initializer is provided, - it adds a bias vector to the output. - It then optionally applies an activation function to produce the final output. - - Arguments: - inputs: Input tensor. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: A single integer specifying the spatial - dimensions of the filters. - strides: A single integer specifying the strides - of the convolution. - Specifying any `stride` value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: A single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - depth_multiplier: The number of depthwise convolution output channels for - each input channel. The total number of depthwise convolution output - channels will be equal to `num_filters_in * depth_multiplier`. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - depthwise_initializer: An initializer for the depthwise convolution kernel. - pointwise_initializer: An initializer for the pointwise convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - depthwise_regularizer: Optional regularizer for the depthwise - convolution kernel. - pointwise_regularizer: Optional regularizer for the pointwise - convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - depthwise_constraint: Optional projection function to be applied to the - depthwise kernel after being updated by an `Optimizer` (e.g. used for - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - pointwise_constraint: Optional projection function to be applied to the - pointwise kernel after being updated by an `Optimizer`. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. 
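As the Raises section notes, these functional wrappers are graph-mode only; a sketch of the one-shot functional call, with the input shape assumed for illustration:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # tf.layers.* functions raise ValueError in eager mode

signal = tf.placeholder(tf.float32, shape=(8, 100, 4))

# Builds a SeparableConv1D layer internally and applies it in one call.
out = tf.layers.separable_conv1d(signal, filters=16, kernel_size=5,
                                 padding='same')
```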
- """ - layer = SeparableConv1D( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - pointwise_initializer=pointwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - pointwise_regularizer=pointwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - pointwise_constraint=pointwise_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs) - - -@deprecation.deprecated( - date=None, - instructions='Use `tf.keras.layers.SeparableConv2D` instead.') -@tf_export(v1=['layers.separable_conv2d']) -def separable_conv2d(inputs, - filters, - kernel_size, - strides=(1, 1), - padding='valid', - data_format='channels_last', - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer=None, - pointwise_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for the depthwise separable 2D convolution layer. - - This layer performs a depthwise convolution that acts separately on - channels, followed by a pointwise convolution that mixes channels. - If `use_bias` is True and a bias initializer is provided, - it adds a bias vector to the output. - It then optionally applies an activation function to produce the final output. - - Arguments: - inputs: Input tensor. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: A tuple or list of 2 integers specifying the spatial - dimensions of the filters. Can be a single integer to specify the same - value for all spatial dimensions. - strides: A tuple or list of 2 positive integers specifying the strides - of the convolution. Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any `stride` value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - depth_multiplier: The number of depthwise convolution output channels for - each input channel. The total number of depthwise convolution output - channels will be equal to `num_filters_in * depth_multiplier`. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. 
- depthwise_initializer: An initializer for the depthwise convolution kernel. - pointwise_initializer: An initializer for the pointwise convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - depthwise_regularizer: Optional regularizer for the depthwise - convolution kernel. - pointwise_regularizer: Optional regularizer for the pointwise - convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - depthwise_constraint: Optional projection function to be applied to the - depthwise kernel after being updated by an `Optimizer` (e.g. used for - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - pointwise_constraint: Optional projection function to be applied to the - pointwise kernel after being updated by an `Optimizer`. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = SeparableConv2D( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - pointwise_initializer=pointwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - pointwise_regularizer=pointwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - pointwise_constraint=pointwise_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.Conv2DTranspose']) -class Conv2DTranspose(keras_layers.Conv2DTranspose, base.Layer): - """Transposed 2D convolution layer (sometimes called 2D Deconvolution). - - The need for transposed convolutions generally arises - from the desire to use a transformation going in the opposite direction - of a normal convolution, i.e., from something that has the shape of the - output of some convolution to something that has the shape of its input - while maintaining a connectivity pattern that is compatible with - said convolution. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: A tuple or list of 2 positive integers specifying the spatial - dimensions of the filters. Can be a single integer to specify the same - value for all spatial dimensions. - strides: A tuple or list of 2 positive integers specifying the strides - of the convolution. Can be a single integer to specify the same value for - all spatial dimensions. - padding: one of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. 
- The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - """ - - def __init__(self, filters, - kernel_size, - strides=(1, 1), - padding='valid', - data_format='channels_last', - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(Conv2DTranspose, self).__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - **kwargs) - - -@deprecation.deprecated( - date=None, - instructions='Use `tf.keras.layers.Conv2DTranspose` instead.') -@tf_export(v1=['layers.conv2d_transpose']) -def conv2d_transpose(inputs, - filters, - kernel_size, - strides=(1, 1), - padding='valid', - data_format='channels_last', - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for transposed 2D convolution layer. - - The need for transposed convolutions generally arises - from the desire to use a transformation going in the opposite direction - of a normal convolution, i.e., from something that has the shape of the - output of some convolution to something that has the shape of its input - while maintaining a connectivity pattern that is compatible with - said convolution. - - Arguments: - inputs: Input tensor. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). 
- kernel_size: A tuple or list of 2 positive integers specifying the spatial - dimensions of the filters. Can be a single integer to specify the same - value for all spatial dimensions. - strides: A tuple or list of 2 positive integers specifying the strides - of the convolution. Can be a single integer to specify the same value for - all spatial dimensions. - padding: one of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - activation: Activation function. Set it to `None` to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If `None`, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = Conv2DTranspose( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.Conv3DTranspose']) -class Conv3DTranspose(keras_layers.Conv3DTranspose, base.Layer): - """Transposed 3D convolution layer (sometimes called 3D Deconvolution). - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for all spatial - dimensions. - strides: An integer or tuple/list of 3 integers, specifying the strides - of the convolution along the depth, height and width. - Can be a single integer to specify the same value for all spatial - dimensions. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. 
- The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - activation: Activation function. Set it to `None` to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If `None`, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - """ - - def __init__(self, - filters, - kernel_size, - strides=(1, 1, 1), - padding='valid', - data_format='channels_last', - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(Conv3DTranspose, self).__init__( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - **kwargs) - - -@deprecation.deprecated( - date=None, - instructions='Use `tf.keras.layers.Conv3DTranspose` instead.') -@tf_export(v1=['layers.conv3d_transpose']) -def conv3d_transpose(inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding='valid', - data_format='channels_last', - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for transposed 3D convolution layer. - - Arguments: - inputs: Input tensor. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: A tuple or list of 3 positive integers specifying the spatial - dimensions of the filters. Can be a single integer to specify the same - value for all spatial dimensions. - strides: A tuple or list of 3 positive integers specifying the strides - of the convolution. Can be a single integer to specify the same value for - all spatial dimensions. 
- padding: one of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - activation: Activation function. Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: An initializer for the convolution kernel. - bias_initializer: An initializer for the bias vector. If None, the default - initializer will be used. - kernel_regularizer: Optional regularizer for the convolution kernel. - bias_regularizer: Optional regularizer for the bias vector. - activity_regularizer: Optional regularizer function for the output. - kernel_constraint: Optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: Optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = Conv3DTranspose( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs) - +from tensorflow.python.keras.legacy_tf_layers import convolutional + +Conv1D = convolutional.Conv1D +conv1d = convolutional.conv1d +Conv2D = convolutional.Conv2D +conv2d = convolutional.conv2d +Conv3D = convolutional.Conv3D +conv3d = convolutional.conv3d +SeparableConv1D = convolutional.SeparableConv1D +SeparableConv2D = convolutional.SeparableConv2D +separable_conv1d = convolutional.separable_conv1d +separable_conv2d = convolutional.separable_conv2d +Conv2DTranspose = convolutional.Conv2DTranspose +conv2d_transpose = convolutional.conv2d_transpose +Conv3DTranspose = convolutional.Conv3DTranspose +conv3d_transpose = convolutional.conv3d_transpose # Aliases diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index dc9293d8072..b0c7d400ba6 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -21,316 +21,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - -from tensorflow.python.keras import layers as keras_layers -from tensorflow.python.layers import base -from tensorflow.python.ops import init_ops -from tensorflow.python.util import deprecation -from 
tensorflow.python.util.tf_export import tf_export +from tensorflow.python.keras.legacy_tf_layers import core -@tf_export(v1=['layers.Dense']) -class Dense(keras_layers.Dense, base.Layer): - """Densely-connected layer class. - - This layer implements the operation: - `outputs = activation(inputs * kernel + bias)` - Where `activation` is the activation function passed as the `activation` - argument (if not `None`), `kernel` is a weights matrix created by the layer, - and `bias` is a bias vector created by the layer - (only if `use_bias` is `True`). - - Arguments: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (callable). Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the default - initializer used by `tf.compat.v1.get_variable`. - bias_initializer: Initializer function for the bias. - kernel_regularizer: Regularizer function for the weight matrix. - bias_regularizer: Regularizer function for the bias. - activity_regularizer: Regularizer function for the output. - kernel_constraint: An optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: An optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: String, the name of the layer. Layers with the same name will - share weights, but to avoid mistakes we require reuse=True in such cases. - _reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (callable). - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: Initializer instance (or name) for the kernel matrix. - bias_initializer: Initializer instance (or name) for the bias. - kernel_regularizer: Regularizer instance for the kernel matrix (callable) - bias_regularizer: Regularizer instance for the bias (callable). - activity_regularizer: Regularizer instance for the output (callable) - kernel_constraint: Constraint function for the kernel matrix. - bias_constraint: Constraint function for the bias. - kernel: Weight matrix (TensorFlow variable or tensor). - bias: Bias vector, if applicable (TensorFlow variable or tensor). 
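A short sketch of the documented operation and properties (shapes are assumptions for illustration):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=(None, 128))

# outputs = activation(inputs * kernel + bias)
layer = tf.layers.Dense(units=64, activation=tf.nn.relu)
y = layer(x)  # shape: (None, 64)

# Once built, the weights are reachable through the documented properties.
print(layer.kernel.shape)  # (128, 64)
print(layer.bias.shape)    # (64,)
```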
- """ - - def __init__(self, units, - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - **kwargs): - super(Dense, self).__init__(units=units, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - **kwargs) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.Dense instead.') -@tf_export(v1=['layers.dense']) -def dense( - inputs, units, - activation=None, - use_bias=True, - kernel_initializer=None, - bias_initializer=init_ops.zeros_initializer(), - kernel_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - bias_constraint=None, - trainable=True, - name=None, - reuse=None): - """Functional interface for the densely-connected layer. - - This layer implements the operation: - `outputs = activation(inputs * kernel + bias)` - where `activation` is the activation function passed as the `activation` - argument (if not `None`), `kernel` is a weights matrix created by the layer, - and `bias` is a bias vector created by the layer - (only if `use_bias` is `True`). - - Arguments: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (callable). Set it to None to maintain a - linear activation. - use_bias: Boolean, whether the layer uses a bias. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the default - initializer used by `tf.compat.v1.get_variable`. - bias_initializer: Initializer function for the bias. - kernel_regularizer: Regularizer function for the weight matrix. - bias_regularizer: Regularizer function for the bias. - activity_regularizer: Regularizer function for the output. - kernel_constraint: An optional projection function to be applied to the - kernel after being updated by an `Optimizer` (e.g. used to implement - norm constraints or value constraints for layer weights). The function - must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are - not safe to use when doing asynchronous distributed training. - bias_constraint: An optional projection function to be applied to the - bias after being updated by an `Optimizer`. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: String, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor the same shape as `inputs` except the last dimension is of - size `units`. - - Raises: - ValueError: if eager execution is enabled. 
- """ - layer = Dense(units, - activation=activation, - use_bias=use_bias, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - kernel_constraint=kernel_constraint, - bias_constraint=bias_constraint, - trainable=trainable, - name=name, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -@tf_export(v1=['layers.Dropout']) -class Dropout(keras_layers.Dropout, base.Layer): - """Applies Dropout to the input. - - Dropout consists in randomly setting a fraction `rate` of input units to 0 - at each update during training time, which helps prevent overfitting. - The units that are kept are scaled by `1 / (1 - rate)`, so that their - sum is unchanged at training time and inference time. - - Arguments: - rate: The dropout rate, between 0 and 1. E.g. `rate=0.1` would drop out - 10% of input units. - noise_shape: 1D tensor of type `int32` representing the shape of the - binary dropout mask that will be multiplied with the input. - For instance, if your inputs have shape - `(batch_size, timesteps, features)`, and you want the dropout mask - to be the same for all timesteps, you can use - `noise_shape=[batch_size, 1, features]`. - seed: A Python integer. Used to create random seeds. See - `tf.compat.v1.set_random_seed`. - for behavior. - name: The name of the layer (string). - """ - - def __init__(self, rate=0.5, - noise_shape=None, - seed=None, - name=None, - **kwargs): - super(Dropout, self).__init__(rate=rate, - noise_shape=noise_shape, - seed=seed, - name=name, - **kwargs) - - def call(self, inputs, training=False): - return super(Dropout, self).call(inputs, training=training) - - -@deprecation.deprecated( - date=None, - instructions='Use keras.layers.dropout instead.') -@tf_export(v1=['layers.dropout']) -def dropout(inputs, - rate=0.5, - noise_shape=None, - seed=None, - training=False, - name=None): - """Applies Dropout to the input. - - Dropout consists in randomly setting a fraction `rate` of input units to 0 - at each update during training time, which helps prevent overfitting. - The units that are kept are scaled by `1 / (1 - rate)`, so that their - sum is unchanged at training time and inference time. - - Arguments: - inputs: Tensor input. - rate: The dropout rate, between 0 and 1. E.g. "rate=0.1" would drop out - 10% of input units. - noise_shape: 1D tensor of type `int32` representing the shape of the - binary dropout mask that will be multiplied with the input. - For instance, if your inputs have shape - `(batch_size, timesteps, features)`, and you want the dropout mask - to be the same for all timesteps, you can use - `noise_shape=[batch_size, 1, features]`. - seed: A Python integer. Used to create random seeds. See - `tf.compat.v1.set_random_seed` - for behavior. - training: Either a Python boolean, or a TensorFlow boolean scalar tensor - (e.g. a placeholder). Whether to return the output in training mode - (apply dropout) or in inference mode (return the input untouched). - name: The name of the layer (string). - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = Dropout(rate, noise_shape=noise_shape, seed=seed, name=name) - return layer.apply(inputs, training=training) - - -@tf_export(v1=['layers.Flatten']) -class Flatten(keras_layers.Flatten, base.Layer): - """Flattens an input tensor while preserving the batch axis (axis 0). 
- - Arguments: - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - - Examples: - - ``` - x = tf.compat.v1.placeholder(shape=(None, 4, 4), dtype='float32') - y = Flatten()(x) - # now `y` has shape `(None, 16)` - - x = tf.compat.v1.placeholder(shape=(None, 3, None), dtype='float32') - y = Flatten()(x) - # now `y` has shape `(None, None)` - ``` - """ - pass - - -@deprecation.deprecated( - date=None, - instructions='Use keras.layers.Flatten instead.') -@tf_export(v1=['layers.flatten']) -def flatten(inputs, name=None, data_format='channels_last'): - """Flattens an input tensor while preserving the batch axis (axis 0). - - Arguments: - inputs: Tensor input. - name: The name of the layer (string). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - Returns: - Reshaped tensor. - - Examples: - - ``` - x = tf.compat.v1.placeholder(shape=(None, 4, 4), dtype='float32') - y = flatten(x) - # now `y` has shape `(None, 16)` - - x = tf.compat.v1.placeholder(shape=(None, 3, None), dtype='float32') - y = flatten(x) - # now `y` has shape `(None, None)` - ``` - """ - layer = Flatten(name=name, data_format=data_format) - return layer.apply(inputs) - +Dense = core.Dense +dense = core.dense +Dropout = core.Dropout +dropout = core.dropout +Flatten = core.Flatten +flatten = core.flatten # Aliases diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 98f042637d0..04ab985058d 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -20,322 +20,11 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras.layers import normalization as keras_normalization -from tensorflow.python.layers import base -from tensorflow.python.ops import init_ops -from tensorflow.python.util import deprecation -from tensorflow.python.util.tf_export import tf_export - - -@tf_export(v1=['layers.BatchNormalization']) -class BatchNormalization(keras_normalization.BatchNormalization, base.Layer): - """Batch Normalization layer from (Ioffe et al., 2015). - - Keras APIs handle BatchNormalization updates to the moving_mean and - moving_variance as part of their `fit()` and `evaluate()` loops. However, if a - custom training loop is used with an instance of `Model`, these updates need - to be explicitly included. Here's a simple example of how it can be done: - - ```python - # model is an instance of Model that contains BatchNormalization layer. - update_ops = model.get_updates_for(None) + model.get_updates_for(features) - train_op = optimizer.minimize(loss) - train_op = tf.group([train_op, update_ops]) - ``` - - Arguments: - axis: An `int` or list of `int`, the axis or axes that should be normalized, - typically the features axis/axes. For instance, after a `Conv2D` layer - with `data_format="channels_first"`, set `axis=1`. If a list of axes is - provided, each axis in `axis` will be normalized - simultaneously. Default is `-1` which uses the last axis. 
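A quick sanity check of the `core.py` re-export shim introduced in the hunk above (a sketch, assuming a TensorFlow build that already contains this change):

```python
import tensorflow as tf

# The public v1 classes still subclass their Keras counterparts.
assert issubclass(tf.compat.v1.layers.Dense, tf.keras.layers.Dense)
assert issubclass(tf.compat.v1.layers.Dropout, tf.keras.layers.Dropout)
assert issubclass(tf.compat.v1.layers.Flatten, tf.keras.layers.Flatten)

# After the move they are defined in the keras legacy package.
print(tf.compat.v1.layers.Dense.__module__)
# expected (after this change): tensorflow.python.keras.legacy_tf_layers.core
```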
Note: when - using multi-axis batch norm, the `beta`, `gamma`, `moving_mean`, and - `moving_variance` variables are the same rank as the input Tensor, - with dimension size 1 in all reduced (non-axis) dimensions. - momentum: Momentum for the moving average. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. If False, `beta` - is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the - next layer is linear (also e.g. `nn.relu`), this can be disabled since the - scaling can be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - moving_mean_initializer: Initializer for the moving mean. - moving_variance_initializer: Initializer for the moving variance. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - beta_constraint: An optional projection function to be applied to the `beta` - weight after being updated by an `Optimizer` (e.g. used to implement norm - constraints or value constraints for layer weights). The function must - take as input the unprojected variable and must return the projected - variable (which must have the same shape). Constraints are not safe to use - when doing asynchronous distributed training. - gamma_constraint: An optional projection function to be applied to the - `gamma` weight after being updated by an `Optimizer`. - renorm: Whether to use Batch Renormalization (Ioffe, 2017). This adds extra - variables during training. The inference is the same for either value of - this parameter. - renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to - scalar `Tensors` used to clip the renorm correction. The correction `(r, - d)` is used as `corrected_value = normalized_value * r + d`, with `r` - clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, - dmax are set to inf, 0, inf, respectively. - renorm_momentum: Momentum used to update the moving means and standard - deviations with renorm. Unlike `momentum`, this affects training and - should be neither too small (which would add noise) nor too large (which - would give stale estimates). Note that `momentum` is still applied to get - the means and variances for inference. - fused: If `None` or `True`, use a faster, fused implementation if possible. - If `False`, use the system recommended implementation. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, - which means batch normalization is performed across the whole batch. When - `virtual_batch_size` is not `None`, instead perform "Ghost Batch - Normalization", which creates virtual sub-batches which are each - normalized separately (with shared gamma, beta, and moving statistics). - Must divide the actual batch size during execution. - adjustment: A function taking the `Tensor` containing the (dynamic) shape of - the input tensor and returning a pair (scale, bias) to apply to the - normalized values (before gamma and beta), only during training.
For - example, if axis==-1, - `adjustment = lambda shape: ( - tf.random.uniform(shape[-1:], 0.93, 1.07), - tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized - value by up to 7% up or down, then shift the result by up to 0.1 - (with independent scaling and bias for each feature but shared - across all examples), and finally apply gamma and/or beta. If - `None`, no adjustment is applied. Cannot be specified if - virtual_batch_size is specified. - name: A string, the name of the layer. - References: - Batch Normalization - Accelerating Deep Network Training by Reducing - Internal Covariate Shift: - [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html) - ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf)) - Batch Renormalization - Towards Reducing Minibatch Dependence in - Batch-Normalized Models: - [Ioffe, - 2017](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models) - ([pdf](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models.pdf)) - """ - - def __init__(self, - axis=-1, - momentum=0.99, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer=init_ops.zeros_initializer(), - gamma_initializer=init_ops.ones_initializer(), - moving_mean_initializer=init_ops.zeros_initializer(), - moving_variance_initializer=init_ops.ones_initializer(), - beta_regularizer=None, - gamma_regularizer=None, - beta_constraint=None, - gamma_constraint=None, - renorm=False, - renorm_clipping=None, - renorm_momentum=0.99, - fused=None, - trainable=True, - virtual_batch_size=None, - adjustment=None, - name=None, - **kwargs): - super(BatchNormalization, self).__init__( - axis=axis, - momentum=momentum, - epsilon=epsilon, - center=center, - scale=scale, - beta_initializer=beta_initializer, - gamma_initializer=gamma_initializer, - moving_mean_initializer=moving_mean_initializer, - moving_variance_initializer=moving_variance_initializer, - beta_regularizer=beta_regularizer, - gamma_regularizer=gamma_regularizer, - beta_constraint=beta_constraint, - gamma_constraint=gamma_constraint, - renorm=renorm, - renorm_clipping=renorm_clipping, - renorm_momentum=renorm_momentum, - fused=fused, - trainable=trainable, - virtual_batch_size=virtual_batch_size, - adjustment=adjustment, - name=name, - **kwargs) - - def call(self, inputs, training=False): - return super(BatchNormalization, self).call(inputs, training=training) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.BatchNormalization instead. In ' - 'particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not ' - 'be used (consult the `tf.keras.layers.BatchNormalization` ' - 'documentation).') -@tf_export(v1=['layers.batch_normalization']) -def batch_normalization(inputs, - axis=-1, - momentum=0.99, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer=init_ops.zeros_initializer(), - gamma_initializer=init_ops.ones_initializer(), - moving_mean_initializer=init_ops.zeros_initializer(), - moving_variance_initializer=init_ops.ones_initializer(), - beta_regularizer=None, - gamma_regularizer=None, - beta_constraint=None, - gamma_constraint=None, - training=False, - trainable=True, - name=None, - reuse=None, - renorm=False, - renorm_clipping=None, - renorm_momentum=0.99, - fused=None, - virtual_batch_size=None, - adjustment=None): - """Functional interface for the batch normalization layer from_config(Ioffe et al., 2015). 
- - Note: when training, the moving_mean and moving_variance need to be updated. - By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they - need to be executed alongside the `train_op`. Also, be sure to add any - batch_normalization ops before getting the update_ops collection. Otherwise, - update_ops will be empty, and training/inference will not work properly. For - example: - - ```python - x_norm = tf.compat.v1.layers.batch_normalization(x, training=training) - - # ... - - update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS) - train_op = optimizer.minimize(loss) - train_op = tf.group([train_op, update_ops]) - ``` - - Arguments: - inputs: Tensor input. - axis: An `int`, the axis that should be normalized (typically the features - axis). For instance, after a `Convolution2D` layer with - `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. - momentum: Momentum for the moving average. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. If False, `beta` - is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the - next layer is linear (also e.g. `nn.relu`), this can be disabled since the - scaling can be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - moving_mean_initializer: Initializer for the moving mean. - moving_variance_initializer: Initializer for the moving variance. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - beta_constraint: An optional projection function to be applied to the `beta` - weight after being updated by an `Optimizer` (e.g. used to implement norm - constraints or value constraints for layer weights). The function must - take as input the unprojected variable and must return the projected - variable (which must have the same shape). Constraints are not safe to use - when doing asynchronous distributed training. - gamma_constraint: An optional projection function to be applied to the - `gamma` weight after being updated by an `Optimizer`. - training: Either a Python boolean, or a TensorFlow boolean scalar tensor - (e.g. a placeholder). Whether to return the output in training mode - (normalized with statistics of the current batch) or in inference mode - (normalized with moving statistics). **NOTE**: make sure to set this - parameter correctly, or else your training/inference will not work - properly. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - name: String, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer by the same - name. - renorm: Whether to use Batch Renormalization (Ioffe, 2017). This adds extra - variables during training. The inference is the same for either value of - this parameter. - renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to - scalar `Tensors` used to clip the renorm correction. The correction `(r, - d)` is used as `corrected_value = normalized_value * r + d`, with `r` - clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, - dmax are set to inf, 0, inf, respectively. - renorm_momentum: Momentum used to update the moving means and standard - deviations with renorm. 
Unlike `momentum`, this affects training and - should be neither too small (which would add noise) nor too large (which - would give stale estimates). Note that `momentum` is still applied to get - the means and variances for inference. - fused: if `None` or `True`, use a faster, fused implementation if possible. - If `False`, use the system recommended implementation. - virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, - which means batch normalization is performed across the whole batch. When - `virtual_batch_size` is not `None`, instead perform "Ghost Batch - Normalization", which creates virtual sub-batches which are each - normalized separately (with shared gamma, beta, and moving statistics). - Must divide the actual batch size during execution. - adjustment: A function taking the `Tensor` containing the (dynamic) shape of - the input tensor and returning a pair (scale, bias) to apply to the - normalized values (before gamma and beta), only during training. For - example, if axis==-1, - `adjustment = lambda shape: ( - tf.random.uniform(shape[-1:], 0.93, 1.07), - tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized - value by up to 7% up or down, then shift the result by up to 0.1 - (with independent scaling and bias for each feature but shared - across all examples), and finally apply gamma and/or beta. If - `None`, no adjustment is applied. Cannot be specified if - virtual_batch_size is specified. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - References: - Batch Normalization - Accelerating Deep Network Training by Reducing - Internal Covariate Shift: - [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html) - ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf)) - Batch Renormalization - Towards Reducing Minibatch Dependence in - Batch-Normalized Models: - [Ioffe, - 2017](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models) - ([pdf](http://papers.nips.cc/paper/6790-batch-renormalization-towards-reducing-minibatch-dependence-in-batch-normalized-models.pdf)) - """ - layer = BatchNormalization( - axis=axis, - momentum=momentum, - epsilon=epsilon, - center=center, - scale=scale, - beta_initializer=beta_initializer, - gamma_initializer=gamma_initializer, - moving_mean_initializer=moving_mean_initializer, - moving_variance_initializer=moving_variance_initializer, - beta_regularizer=beta_regularizer, - gamma_regularizer=gamma_regularizer, - beta_constraint=beta_constraint, - gamma_constraint=gamma_constraint, - renorm=renorm, - renorm_clipping=renorm_clipping, - renorm_momentum=renorm_momentum, - fused=fused, - trainable=trainable, - virtual_batch_size=virtual_batch_size, - adjustment=adjustment, - name=name, - _reuse=reuse, - _scope=name) - return layer.apply(inputs, training=training) +from tensorflow.python.keras.legacy_tf_layers import normalization +BatchNormalization = normalization.BatchNormalization +batch_normalization = normalization.batch_normalization # Aliases BatchNorm = BatchNormalization diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py index 2dbdc099742..5737f1ff09e 100644 --- a/tensorflow/python/layers/pooling.py +++ b/tensorflow/python/layers/pooling.py @@ -19,448 +19,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras import layers as keras_layers -from tensorflow.python.layers 
import base -from tensorflow.python.util import deprecation -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.keras.legacy_tf_layers import pooling -@tf_export(v1=['layers.AveragePooling1D']) -class AveragePooling1D(keras_layers.AveragePooling1D, base.Layer): - """Average Pooling layer for 1D inputs. - - Arguments: - pool_size: An integer or tuple/list of a single integer, - representing the size of the pooling window. - strides: An integer or tuple/list of a single integer, specifying the - strides of the pooling operation. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - name: A string, the name of the layer. - """ - - def __init__(self, pool_size, strides, - padding='valid', data_format='channels_last', - name=None, **kwargs): - if strides is None: - raise ValueError('Argument `strides` must not be None.') - super(AveragePooling1D, self).__init__( - pool_size=pool_size, - strides=strides, - padding=padding, - data_format=data_format, - name=name, - **kwargs) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.AveragePooling1D instead.') -@tf_export(v1=['layers.average_pooling1d']) -def average_pooling1d(inputs, pool_size, strides, - padding='valid', data_format='channels_last', - name=None): - """Average Pooling layer for 1D inputs. - - Arguments: - inputs: The tensor over which to pool. Must have rank 3. - pool_size: An integer or tuple/list of a single integer, - representing the size of the pooling window. - strides: An integer or tuple/list of a single integer, specifying the - strides of the pooling operation. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - name: A string, the name of the layer. - - Returns: - The output tensor, of rank 3. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = AveragePooling1D(pool_size=pool_size, - strides=strides, - padding=padding, - data_format=data_format, - name=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.MaxPooling1D']) -class MaxPooling1D(keras_layers.MaxPooling1D, base.Layer): - """Max Pooling layer for 1D inputs. - - Arguments: - pool_size: An integer or tuple/list of a single integer, - representing the size of the pooling window. - strides: An integer or tuple/list of a single integer, specifying the - strides of the pooling operation. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - name: A string, the name of the layer. 
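For the 1D poolers above, a shape-level sketch (sizes are illustrative; the 'valid' and 'same' length arithmetic is spelled out in the comments):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=(None, 100, 8))  # (batch, length, channels)

# 'valid': output length = floor((100 - 4) / 2) + 1 = 49
avg = tf.layers.average_pooling1d(x, pool_size=4, strides=2, padding='valid')
print(avg.shape)  # (?, 49, 8)

# 'same': output length = ceil(100 / 2) = 50
mx = tf.layers.max_pooling1d(x, pool_size=4, strides=2, padding='same')
print(mx.shape)  # (?, 50, 8)
```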
- """ - - def __init__(self, pool_size, strides, - padding='valid', data_format='channels_last', - name=None, **kwargs): - if strides is None: - raise ValueError('Argument `strides` must not be None.') - super(MaxPooling1D, self).__init__( - pool_size=pool_size, - strides=strides, - padding=padding, - data_format=data_format, - name=name, - **kwargs) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.MaxPooling1D instead.') -@tf_export(v1=['layers.max_pooling1d']) -def max_pooling1d(inputs, pool_size, strides, - padding='valid', data_format='channels_last', - name=None): - """Max Pooling layer for 1D inputs. - - Arguments: - inputs: The tensor over which to pool. Must have rank 3. - pool_size: An integer or tuple/list of a single integer, - representing the size of the pooling window. - strides: An integer or tuple/list of a single integer, specifying the - strides of the pooling operation. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - name: A string, the name of the layer. - - Returns: - The output tensor, of rank 3. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = MaxPooling1D(pool_size=pool_size, - strides=strides, - padding=padding, - data_format=data_format, - name=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.AveragePooling2D']) -class AveragePooling2D(keras_layers.AveragePooling2D, base.Layer): - """Average pooling layer for 2D inputs (e.g. images). - - Arguments: - pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) - specifying the size of the pooling window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the pooling operation. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - name: A string, the name of the layer. - """ - - def __init__(self, pool_size, strides, - padding='valid', data_format='channels_last', - name=None, **kwargs): - if strides is None: - raise ValueError('Argument `strides` must not be None.') - super(AveragePooling2D, self).__init__( - pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, name=name, **kwargs) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.AveragePooling2D instead.') -@tf_export(v1=['layers.average_pooling2d']) -def average_pooling2d(inputs, - pool_size, strides, - padding='valid', data_format='channels_last', - name=None): - """Average pooling layer for 2D inputs (e.g. images). - - Arguments: - inputs: The tensor over which to pool. Must have rank 4. - pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) - specifying the size of the pooling window. 
- Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the pooling operation. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - name: A string, the name of the layer. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = AveragePooling2D(pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, - name=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.MaxPooling2D']) -class MaxPooling2D(keras_layers.MaxPooling2D, base.Layer): - """Max pooling layer for 2D inputs (e.g. images). - - Arguments: - pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) - specifying the size of the pooling window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the pooling operation. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - name: A string, the name of the layer. - """ - - def __init__(self, pool_size, strides, - padding='valid', data_format='channels_last', - name=None, **kwargs): - if strides is None: - raise ValueError('Argument `strides` must not be None.') - super(MaxPooling2D, self).__init__( - pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, name=name, **kwargs) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.MaxPooling2D instead.') -@tf_export(v1=['layers.max_pooling2d']) -def max_pooling2d(inputs, - pool_size, strides, - padding='valid', data_format='channels_last', - name=None): - """Max pooling layer for 2D inputs (e.g. images). - - Arguments: - inputs: The tensor over which to pool. Must have rank 4. - pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) - specifying the size of the pooling window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the pooling operation. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - name: A string, the name of the layer. - - Returns: - Output tensor. 
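For instance (a sketch with illustrative shapes), the output shape follows the usual pooling arithmetic:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

images = tf.placeholder(tf.float32, shape=(None, 28, 28, 3))  # channels_last

# 2x2 window, stride 2, 'valid': 28 -> floor((28 - 2) / 2) + 1 = 14
pooled = tf.layers.max_pooling2d(images, pool_size=2, strides=2)
print(pooled.shape)  # (?, 14, 14, 3)
```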
- - Raises: - ValueError: if eager execution is enabled. - """ - layer = MaxPooling2D(pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, - name=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.AveragePooling3D']) -class AveragePooling3D(keras_layers.AveragePooling3D, base.Layer): - """Average pooling layer for 3D inputs (e.g. volumes). - - Arguments: - pool_size: An integer or tuple/list of 3 integers: - (pool_depth, pool_height, pool_width) - specifying the size of the pooling window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the pooling operation. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - name: A string, the name of the layer. - """ - - def __init__(self, pool_size, strides, - padding='valid', data_format='channels_last', - name=None, **kwargs): - if strides is None: - raise ValueError('Argument `strides` must not be None.') - super(AveragePooling3D, self).__init__( - pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, name=name, **kwargs) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.AveragePooling3D instead.') -@tf_export(v1=['layers.average_pooling3d']) -def average_pooling3d(inputs, - pool_size, strides, - padding='valid', data_format='channels_last', - name=None): - """Average pooling layer for 3D inputs (e.g. volumes). - - Arguments: - inputs: The tensor over which to pool. Must have rank 5. - pool_size: An integer or tuple/list of 3 integers: - (pool_depth, pool_height, pool_width) - specifying the size of the pooling window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the pooling operation. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - name: A string, the name of the layer. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - """ - layer = AveragePooling3D(pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, - name=name) - return layer.apply(inputs) - - -@tf_export(v1=['layers.MaxPooling3D']) -class MaxPooling3D(keras_layers.MaxPooling3D, base.Layer): - """Max pooling layer for 3D inputs (e.g. volumes). - - Arguments: - pool_size: An integer or tuple/list of 3 integers: - (pool_depth, pool_height, pool_width) - specifying the size of the pooling window. - Can be a single integer to specify the same value for - all spatial dimensions. 
- strides: An integer or tuple/list of 3 integers, - specifying the strides of the pooling operation. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - name: A string, the name of the layer. - """ - - def __init__(self, pool_size, strides, - padding='valid', data_format='channels_last', - name=None, **kwargs): - if strides is None: - raise ValueError('Argument `strides` must not be None.') - super(MaxPooling3D, self).__init__( - pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, name=name, **kwargs) - - -@deprecation.deprecated( - date=None, instructions='Use keras.layers.MaxPooling3D instead.') -@tf_export(v1=['layers.max_pooling3d']) -def max_pooling3d(inputs, - pool_size, strides, - padding='valid', data_format='channels_last', - name=None): - """Max pooling layer for 3D inputs (e.g. volumes). - - Arguments: - inputs: The tensor over which to pool. Must have rank 5. - pool_size: An integer or tuple/list of 3 integers: (pool_depth, pool_height, - pool_width) specifying the size of the pooling window. Can be a single - integer to specify the same value for all spatial dimensions. - strides: An integer or tuple/list of 3 integers, specifying the strides of - the pooling operation. Can be a single integer to specify the same value - for all spatial dimensions. - padding: A string. The padding method, either 'valid' or 'same'. - Case-insensitive. - data_format: A string. The ordering of the dimensions in the inputs. - `channels_last` (default) and `channels_first` are supported. - `channels_last` corresponds to inputs with shape `(batch, depth, height, - width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - name: A string, the name of the layer. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled.
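And the 3D case behaves analogously (a sketch with illustrative shapes):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

volumes = tf.placeholder(tf.float32, shape=(None, 16, 32, 32, 1))  # (batch, D, H, W, C)

# 2x2x2 window, stride 2, 'valid': every spatial dimension is halved.
pooled = tf.layers.max_pooling3d(volumes, pool_size=2, strides=2)
print(pooled.shape)  # (?, 8, 16, 16, 1)
```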
- """ - layer = MaxPooling3D(pool_size=pool_size, strides=strides, - padding=padding, data_format=data_format, - name=name) - return layer.apply(inputs) +AveragePooling1D = pooling.AveragePooling1D +average_pooling1d = pooling.average_pooling1d +MaxPooling1D = pooling.MaxPooling1D +max_pooling1d = pooling.max_pooling1d +AveragePooling2D = pooling.AveragePooling2D +average_pooling2d = pooling.average_pooling2d +MaxPooling2D = pooling.MaxPooling2D +max_pooling2d = pooling.max_pooling2d +AveragePooling3D = pooling.AveragePooling3D +average_pooling3d = pooling.average_pooling3d +MaxPooling3D = pooling.MaxPooling3D +max_pooling3d = pooling.max_pooling3d # Aliases diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt index e862c75e524..f99dad33a6d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.AveragePooling1D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt index 690decf44b3..eb688a9c676 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.AveragePooling2D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt index 1eb8e3fdb1b..20ee5a52952 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.AveragePooling3D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt index 366e19a6dd7..e910faef781 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.BatchNormalization" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt index 178fb1be257..5ff802f6e48 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.Conv1D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt index 6a518eb565d..cd98e2773ba 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.Conv2DTranspose" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt index abd6d7f6e68..a65847cfda5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.Conv2D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt index 075f164245f..3ea91c6fb66 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.Conv3DTranspose" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt index 271f421ad4c..d39b663299e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.Conv3D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt index daea01b2942..fec6b718cd2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.layers.Dense" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt index 80590162f11..edc81f739e6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.layers.Dropout" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt index 96a12b73633..3cc78accc78 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt +++ 
b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.layers.Flatten" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt index fe594e836d1..2b65efadcef 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.layers.Layer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt index 8a0bcec6740..c05d1d10329 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.MaxPooling1D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt index ab3465609bf..a0e5e6fb050 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.MaxPooling2D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt index 1a1400e838e..788c181a96f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt @@ -1,9 +1,9 @@ path: "tensorflow.layers.MaxPooling3D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt index 2fde0d8f59a..920dcf0f747 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.SeparableConv1D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt index cf628d0f684..33bb2ee8785 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt @@ -1,10 +1,10 @@ path: "tensorflow.layers.SeparableConv2D" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" - is_instance: "" 
+ is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt index 2a5ba923f28..a143468c615 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-t-f-lite-l-s-t-m-cell.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt index b09694a7e4c..fd240a31637 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.experimental.nn.-tf-lite-r-n-n-cell.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index d59011dcf16..02ec119a24a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index 6b1e0239aa5..185bfa99489 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 914a97a86d9..102a2266f5a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index 8fb43eac9a7..bb6bde99e53 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 
b699722de26..832ec6f6be6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index f2787db8f50..6f471d3f811 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index 5a7770783f8..48d17d35fbe 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.nn.rnn_cell.MultiRNNCell" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 0ee22729a11..5c428f658c9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.nn.rnn_cell.RNNCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index 64054d1581e..629d73640f3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: ""