diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 5d57f1d9b93..0d62a32b1fe 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -53,6 +53,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.layers import core as core_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -1404,7 +1405,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
 
     self._run(fn, 10000)
 
   # TODO(b/157587712): Move to keras when benchmarks are setup.
-  def benchmark_tf_keras_layer_call(self):
+  def benchmark_tf_keras_layer_call_overhead(self):
 
     class OnlyOverheadLayer(base_layer.Layer):
@@ -1419,6 +1420,18 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
 
     self._run(fn, 10000)
 
+  # TODO(b/157587712): Move to keras when benchmarks are setup.
+  def benchmark_tf_keras_dense_overhead(self):
+
+    layer = core_layers.Dense(1)
+    x = ops.convert_to_tensor([[1.]])
+    layer(x)  # Warmup call to `build` layer.
+
+    def fn():
+      layer(x)
+
+    self._run(fn, 10000)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index e817d56f619..c7d25f31d73 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -293,12 +293,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape',
-        'batch_input_shape',
-        'batch_size',
-        'weights',
-        'activity_regularizer',
-        'autocast'
+        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
+        'weights', 'activity_regularizer', 'autocast'
     }
     # Validate optional keyword arguments.
     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
@@ -323,7 +319,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
 
     self.supports_masking = False
     self._init_set_name(name)
-    self._activity_regularizer = kwargs.pop('activity_regularizer', None)
+    self._activity_regularizer = regularizers.get(
+        kwargs.pop('activity_regularizer', None))
     self._maybe_create_attribute('_trainable_weights', [])
     self._maybe_create_attribute('_non_trainable_weights', [])
     self._updates = []
@@ -370,6 +367,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     self._dynamic = dynamic
 
     # Manage input shape information if passed.
+    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
+      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
+      kwargs['input_shape'] = (kwargs['input_dim'],)
     if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
       # In this case we will later create an input layer
       # to insert before the current layer
@@ -530,7 +530,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     dtype = dtypes.as_dtype(dtype)
     if self._dtype_policy.variable_dtype is None:
       # The policy is "_infer", so we infer the policy from the variable dtype.
-      self._dtype_policy = policy.Policy(dtype.base_dtype.name)
+      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
     initializer = initializers.get(initializer)
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
@@ -2202,7 +2202,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
   @_dtype.setter
   def _dtype(self, value):
     value = dtypes.as_dtype(value).name
-    self._dtype_policy = policy.Policy(value)
+    self._set_dtype_policy(policy.Policy(value))
 
   def _name_scope(self):
     if not tf2.enabled():
@@ -2436,7 +2436,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       except AttributeError:
         pass
       else:
-        self._dtype_policy = policy.Policy(dtype)
+        self._set_dtype_policy(policy.Policy(dtype))
     input_shapes = None
     # Converts Tensors / CompositeTensors to TensorShapes.
     if all(hasattr(x, 'shape') for x in input_list):
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index ca138d79020..13fb2b28bb7 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -616,6 +616,14 @@ class BaseLayerTest(keras_parameterized.TestCase):
     self.assertTrue(layer.built)
     self.assertEqual([None, 3], layer._build_input_shape.as_list())
 
+  def test_activity_regularizer_string(self):
+
+    class MyLayer(base_layer.Layer):
+      pass
+
+    layer = MyLayer(activity_regularizer='l2')
+    self.assertIsInstance(layer.activity_regularizer, regularizers.L2)
+
 
 class SymbolicSupportTest(keras_parameterized.TestCase):
 
diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py
index 78140985b4a..725334f8535 100644
--- a/tensorflow/python/keras/engine/base_layer_v1.py
+++ b/tensorflow/python/keras/engine/base_layer_v1.py
@@ -158,12 +158,8 @@ class Layer(base_layer.Layer):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape',
-        'batch_input_shape',
-        'batch_size',
-        'weights',
-        'activity_regularizer',
-        'autocast'
+        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
+        'weights', 'activity_regularizer', 'autocast'
     }
     # Validate optional keyword arguments.
     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
@@ -184,7 +180,8 @@ class Layer(base_layer.Layer):
 
     self.supports_masking = False
     self._init_set_name(name)
-    self._activity_regularizer = kwargs.pop('activity_regularizer', None)
+    self._activity_regularizer = regularizers.get(
+        kwargs.pop('activity_regularizer', None))
     self._maybe_create_attribute('_trainable_weights', [])
     self._maybe_create_attribute('_non_trainable_weights', [])
     self._updates = []
@@ -229,6 +226,9 @@ class Layer(base_layer.Layer):
     self._dynamic = dynamic
 
     # Manage input shape information if passed.
+    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
+      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
+      kwargs['input_shape'] = (kwargs['input_dim'],)
     if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
       # In this case we will later create an input layer
       # to insert before the current layer
@@ -378,7 +378,7 @@ class Layer(base_layer.Layer):
     dtype = dtypes.as_dtype(dtype)
     if self._dtype_policy.variable_dtype is None:
       # The policy is "_infer", so we infer the policy from the variable dtype.
-      self._dtype_policy = policy.Policy(dtype.base_dtype.name)
+      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
     initializer = initializers.get(initializer)
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
@@ -1835,7 +1835,7 @@ class Layer(base_layer.Layer):
   @_dtype.setter
   def _dtype(self, value):
     value = dtypes.as_dtype(value).name
-    self._dtype_policy = policy.Policy(value)
+    self._set_dtype_policy(policy.Policy(value))
 
   def _name_scope(self):
     return self.name
@@ -2068,7 +2068,7 @@ class Layer(base_layer.Layer):
       except AttributeError:
         pass
       else:
-        self._dtype_policy = policy.Policy(dtype)
+        self._set_dtype_policy(policy.Policy(dtype))
     input_shapes = None
     if all(hasattr(x, 'shape') for x in input_list):
       input_shapes = nest.map_structure(lambda x: x.shape, inputs)
diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD
index 10a9fe088ab..0b664d01b6a 100644
--- a/tensorflow/python/keras/layers/BUILD
+++ b/tensorflow/python/keras/layers/BUILD
@@ -141,6 +141,7 @@ py_library(
         "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras:regularizers",
         "//tensorflow/python/keras/engine:input_spec",
+        "//tensorflow/python/keras/layers/ops:core",
        "//tensorflow/python/keras/utils:engine_utils",
         "//tensorflow/python/keras/utils:generic_utils",
         "//tensorflow/python/keras/utils:tf_utils",
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 60834fad30b..56512d0d754 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -37,18 +37,15 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.input_spec import InputSpec
+from tensorflow.python.keras.layers.ops import core as core_ops
 from tensorflow.python.keras.utils import conv_utils
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import standard_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.tracking import base as trackable
@@ -1132,11 +1129,8 @@ class Dense(Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
-      kwargs['input_shape'] = (kwargs.pop('input_dim'),)
-
     super(Dense, self).__init__(
-        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+        activity_regularizer=activity_regularizer, **kwargs)
 
     self.units = int(units) if not isinstance(units, int) else units
     self.activation = activations.get(activation)
@@ -1148,19 +1142,20 @@ class Dense(Layer):
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.bias_constraint = constraints.get(bias_constraint)
 
-    self.supports_masking = True
     self.input_spec = InputSpec(min_ndim=2)
+    self.supports_masking = True
 
   def build(self, input_shape):
     dtype = dtypes.as_dtype(self.dtype or K.floatx())
     if not (dtype.is_floating or dtype.is_complex):
       raise TypeError('Unable to build `Dense` layer with non-floating point '
                       'dtype %s' % (dtype,))
+
     input_shape = tensor_shape.TensorShape(input_shape)
-    if tensor_shape.dimension_value(input_shape[-1]) is None:
+    last_dim = tensor_shape.dimension_value(input_shape[-1])
+    if last_dim is None:
       raise ValueError('The last dimension of the inputs to `Dense` '
                        'should be defined. Found `None`.')
-    last_dim = tensor_shape.dimension_value(input_shape[-1])
     self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
     self.kernel = self.add_weight(
         'kernel',
@@ -1184,27 +1179,12 @@ class Dense(Layer):
     self.built = True
 
   def call(self, inputs):
-    base_layer_utils.no_ragged_support(inputs, self.name)
-    rank = inputs.shape.rank
-    if rank is not None and rank > 2:
-      # Broadcasting is required for the inputs.
-      outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
-      # Reshape the output back to the original ndim of the input.
-      if not context.executing_eagerly():
-        shape = inputs.shape.as_list()
-        output_shape = shape[:-1] + [self.units]
-        outputs.set_shape(output_shape)
-    else:
-      inputs = math_ops.cast(inputs, self._compute_dtype)
-      if K.is_sparse(inputs):
-        outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel)
-      else:
-        outputs = gen_math_ops.mat_mul(inputs, self.kernel)
-    if self.use_bias:
-      outputs = nn.bias_add(outputs, self.bias)
-    if self.activation is not None:
-      return self.activation(outputs)  # pylint: disable=not-callable
-    return outputs
+    return core_ops.dense(
+        inputs,
+        self.kernel,
+        self.bias,
+        self.activation,
+        dtype=self._compute_dtype_object)
 
   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
@@ -1216,21 +1196,30 @@ class Dense(Layer):
     return input_shape[:-1].concatenate(self.units)
 
   def get_config(self):
-    config = {
-        'units': self.units,
-        'activation': activations.serialize(self.activation),
-        'use_bias': self.use_bias,
-        'kernel_initializer': initializers.serialize(self.kernel_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
-        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+    config = super(Dense, self).get_config()
+    config.update({
+        'units':
+            self.units,
+        'activation':
+            activations.serialize(self.activation),
+        'use_bias':
+            self.use_bias,
+        'kernel_initializer':
+            initializers.serialize(self.kernel_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
+        'kernel_regularizer':
+            regularizers.serialize(self.kernel_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
         'activity_regularizer':
             regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint': constraints.serialize(self.kernel_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint)
-    }
-    base_config = super(Dense, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+        'kernel_constraint':
+            constraints.serialize(self.kernel_constraint),
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint)
+    })
+    return config
 
 
 @keras_export('keras.layers.ActivityRegularization')
diff --git a/tensorflow/python/keras/layers/ops/BUILD b/tensorflow/python/keras/layers/ops/BUILD
new file mode 100644
index 00000000000..09973c54790
--- /dev/null
+++ b/tensorflow/python/keras/layers/ops/BUILD
@@ -0,0 +1,24 @@
+# Description:
+#   Contains stateless ops for Keras layers.
+
+package(
+    default_visibility = [
+        "//tensorflow/python/keras/layers:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "core",
+    srcs = ["core.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:sparse_ops",
+        "//tensorflow/python:standard_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/eager:context",
+    ],
+)
diff --git a/tensorflow/python/keras/layers/ops/__init__.py b/tensorflow/python/keras/layers/ops/__init__.py
new file mode 100644
index 00000000000..27d099a4898
--- /dev/null
+++ b/tensorflow/python/keras/layers/ops/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless ops for Keras layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/python/keras/layers/ops/core.py b/tensorflow/python/keras/layers/ops/core.py
new file mode 100644
index 00000000000..1a30472cba3
--- /dev/null
+++ b/tensorflow/python/keras/layers/ops/core.py
@@ -0,0 +1,69 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless ops for core Keras layers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import standard_ops
+
+
+# TODO(b/157913406): Expose this publicly.
+def dense(inputs, kernel, bias=None, activation=None, dtype=None):
+  """Densely connected NN layer op.
+
+  Arguments:
+    inputs: `tf.Tensor` or `tf.SparseTensor`. Inputs to operation.
+    kernel: `tf.Variable`. Matrix kernel.
+    bias: (Optional) `tf.Variable`. Bias to add to outputs.
+    activation: (Optional) 1-argument callable. Activation function to apply to
+      outputs.
+    dtype: (Optional) `tf.DType`. Dtype to cast `inputs` to.
+
+  Returns:
+    `tf.Tensor`. Output of dense connection.
+  """
+  if dtype:
+    if inputs.dtype.base_dtype != dtype.base_dtype:
+      inputs = math_ops.cast(inputs, dtype=dtype)
+
+  rank = inputs.shape.rank
+  if rank == 2 or rank is None:
+    if isinstance(inputs, sparse_tensor.SparseTensor):
+      outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, kernel)
+    else:
+      outputs = gen_math_ops.mat_mul(inputs, kernel)
+  # Broadcast kernel to inputs.
+  else:
+    outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])
+    # Reshape the output back to the original ndim of the input.
+    if not context.executing_eagerly():
+      shape = inputs.shape.as_list()
+      output_shape = shape[:-1] + [kernel.shape[-1]]
+      outputs.set_shape(output_shape)
+
+  if bias is not None:
+    outputs = nn_ops.bias_add(outputs, bias)
+
+  if activation is not None:
+    outputs = activation(outputs)
+
+  return outputs