Reduce Dense.__call__ overhead by ~15%
- Creates a layers/ops directory to contain functional implementations of Keras layer ops (this should make it easier to XLA compile these ops, as well as potentially being useful to expose stateless versions of the layer ops)
- Moves Dense op to this directory and implements a functional version of Dense.call with reduced Python overhead
- Uses this op in Dense layer
- Cleans up input_dim and activity_regularizer handling in Dense and Layer
- Adds a microbenchmark for Dense

Also fixes two small issues:
- compute_dtype_object was not always correctly set in Layer
- activity_regularizer can now be passed as a str to any Layer class

PiperOrigin-RevId: 314256375
Change-Id: I769cef6d67aa117f6cb75dc34c0594e748284af2
parent df7fd4acda
commit 70387ab55b
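A rough standalone analogue of the microbenchmark added below, using only the public API; `timeit` here just stands in for the internal harness's `_run` helper, and absolute numbers will differ from the in-tree benchmark:

    import timeit
    import tensorflow as tf

    layer = tf.keras.layers.Dense(1)
    x = tf.constant([[1.]])
    layer(x)  # Warm-up call so `build` runs outside the timed loop.

    # Average per-call wall time over 10,000 calls, mirroring the added benchmark.
    per_call = timeit.timeit(lambda: layer(x), number=10000) / 10000
    print('Dense.__call__ overhead: %.1f us' % (per_call * 1e6))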
@@ -53,6 +53,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.layers import core as core_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -1404,7 +1405,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
     self._run(fn, 10000)

   # TODO(b/157587712): Move to keras when benchmarks are setup.
-  def benchmark_tf_keras_layer_call(self):
+  def benchmark_tf_keras_layer_call_overhead(self):

     class OnlyOverheadLayer(base_layer.Layer):

@@ -1419,6 +1420,18 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):

     self._run(fn, 10000)

+  # TODO(b/157587712): Move to keras when benchmarks are setup.
+  def benchmark_tf_keras_dense_overhead(self):
+
+    layer = core_layers.Dense(1)
+    x = ops.convert_to_tensor([[1.]])
+    layer(x)  # Warmup call to `build` layer.
+
+    def fn():
+      layer(x)
+
+    self._run(fn, 10000)
+

 if __name__ == "__main__":
   test.main()
@@ -293,12 +293,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape',
-        'batch_input_shape',
-        'batch_size',
-        'weights',
-        'activity_regularizer',
-        'autocast'
+        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
+        'weights', 'activity_regularizer', 'autocast'
     }
     # Validate optional keyword arguments.
     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
@@ -323,7 +319,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     self.supports_masking = False

     self._init_set_name(name)
-    self._activity_regularizer = kwargs.pop('activity_regularizer', None)
+    self._activity_regularizer = regularizers.get(
+        kwargs.pop('activity_regularizer', None))
     self._maybe_create_attribute('_trainable_weights', [])
     self._maybe_create_attribute('_non_trainable_weights', [])
     self._updates = []
@@ -370,6 +367,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     self._dynamic = dynamic

     # Manage input shape information if passed.
+    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
+      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
+      kwargs['input_shape'] = (kwargs['input_dim'],)
     if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
       # In this case we will later create an input layer
       # to insert before the current layer
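A quick sketch of what this aliasing preserves: old-style `input_dim` keeps working for any layer now that the base `Layer` rewrites it to `input_shape` (public-API example, not part of the diff):

    import tensorflow as tf

    # input_dim=3 is aliased to input_shape=(3,) in Layer.__init__.
    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_dim=3)])
    print(model.output_shape)  # (None, 2)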
@@ -530,7 +530,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     dtype = dtypes.as_dtype(dtype)
     if self._dtype_policy.variable_dtype is None:
       # The policy is "_infer", so we infer the policy from the variable dtype.
-      self._dtype_policy = policy.Policy(dtype.base_dtype.name)
+      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
     initializer = initializers.get(initializer)
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
@@ -2202,7 +2202,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
   @_dtype.setter
   def _dtype(self, value):
     value = dtypes.as_dtype(value).name
-    self._dtype_policy = policy.Policy(value)
+    self._set_dtype_policy(policy.Policy(value))

   def _name_scope(self):
     if not tf2.enabled():
@@ -2436,7 +2436,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       except AttributeError:
         pass
       else:
-        self._dtype_policy = policy.Policy(dtype)
+        self._set_dtype_policy(policy.Policy(dtype))
     input_shapes = None
     # Converts Tensors / CompositeTensors to TensorShapes.
     if all(hasattr(x, 'shape') for x in input_list):
@@ -616,6 +616,14 @@ class BaseLayerTest(keras_parameterized.TestCase):
     self.assertTrue(layer.built)
     self.assertEqual([None, 3], layer._build_input_shape.as_list())

+  def test_activity_regularizer_string(self):
+
+    class MyLayer(base_layer.Layer):
+      pass
+
+    layer = MyLayer(activity_regularizer='l2')
+    self.assertIsInstance(layer.activity_regularizer, regularizers.L2)
+

 class SymbolicSupportTest(keras_parameterized.TestCase):

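This test exercises the second small fix from the commit message; in user code the string form now resolves for any layer, e.g. (a minimal sketch using the public API):

    import tensorflow as tf

    # 'l2' is resolved by regularizers.get() in Layer.__init__, so the layer
    # ends up with an L2 regularizer instance rather than the raw string.
    layer = tf.keras.layers.Dense(4, activity_regularizer='l2')
    print(layer.activity_regularizer)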
@@ -158,12 +158,8 @@ class Layer(base_layer.Layer):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape',
-        'batch_input_shape',
-        'batch_size',
-        'weights',
-        'activity_regularizer',
-        'autocast'
+        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
+        'weights', 'activity_regularizer', 'autocast'
     }
     # Validate optional keyword arguments.
     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
@@ -184,7 +180,8 @@ class Layer(base_layer.Layer):
     self.supports_masking = False

     self._init_set_name(name)
-    self._activity_regularizer = kwargs.pop('activity_regularizer', None)
+    self._activity_regularizer = regularizers.get(
+        kwargs.pop('activity_regularizer', None))
     self._maybe_create_attribute('_trainable_weights', [])
     self._maybe_create_attribute('_non_trainable_weights', [])
     self._updates = []
@@ -229,6 +226,9 @@ class Layer(base_layer.Layer):
     self._dynamic = dynamic

     # Manage input shape information if passed.
+    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
+      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
+      kwargs['input_shape'] = (kwargs['input_dim'],)
     if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
       # In this case we will later create an input layer
       # to insert before the current layer
@@ -378,7 +378,7 @@ class Layer(base_layer.Layer):
     dtype = dtypes.as_dtype(dtype)
     if self._dtype_policy.variable_dtype is None:
       # The policy is "_infer", so we infer the policy from the variable dtype.
-      self._dtype_policy = policy.Policy(dtype.base_dtype.name)
+      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
     initializer = initializers.get(initializer)
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
@@ -1835,7 +1835,7 @@ class Layer(base_layer.Layer):
   @_dtype.setter
   def _dtype(self, value):
     value = dtypes.as_dtype(value).name
-    self._dtype_policy = policy.Policy(value)
+    self._set_dtype_policy(policy.Policy(value))

   def _name_scope(self):
     return self.name
@@ -2068,7 +2068,7 @@ class Layer(base_layer.Layer):
       except AttributeError:
         pass
       else:
-        self._dtype_policy = policy.Policy(dtype)
+        self._set_dtype_policy(policy.Policy(dtype))
     input_shapes = None
     if all(hasattr(x, 'shape') for x in input_list):
       input_shapes = nest.map_structure(lambda x: x.shape, inputs)
@@ -141,6 +141,7 @@ py_library(
         "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras:regularizers",
         "//tensorflow/python/keras/engine:input_spec",
+        "//tensorflow/python/keras/layers/ops:core",
         "//tensorflow/python/keras/utils:engine_utils",
         "//tensorflow/python/keras/utils:generic_utils",
         "//tensorflow/python/keras/utils:tf_utils",
@@ -37,18 +37,15 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.input_spec import InputSpec
+from tensorflow.python.keras.layers.ops import core as core_ops
 from tensorflow.python.keras.utils import conv_utils
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import standard_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.tracking import base as trackable
@@ -1132,11 +1129,8 @@ class Dense(Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
-      kwargs['input_shape'] = (kwargs.pop('input_dim'),)
-
     super(Dense, self).__init__(
-        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+        activity_regularizer=activity_regularizer, **kwargs)

     self.units = int(units) if not isinstance(units, int) else units
     self.activation = activations.get(activation)
@@ -1148,19 +1142,20 @@ class Dense(Layer):
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.bias_constraint = constraints.get(bias_constraint)

-    self.supports_masking = True
     self.input_spec = InputSpec(min_ndim=2)
+    self.supports_masking = True

   def build(self, input_shape):
     dtype = dtypes.as_dtype(self.dtype or K.floatx())
     if not (dtype.is_floating or dtype.is_complex):
       raise TypeError('Unable to build `Dense` layer with non-floating point '
                       'dtype %s' % (dtype,))
+
     input_shape = tensor_shape.TensorShape(input_shape)
-    if tensor_shape.dimension_value(input_shape[-1]) is None:
+    last_dim = tensor_shape.dimension_value(input_shape[-1])
+    if last_dim is None:
       raise ValueError('The last dimension of the inputs to `Dense` '
                        'should be defined. Found `None`.')
-    last_dim = tensor_shape.dimension_value(input_shape[-1])
     self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
     self.kernel = self.add_weight(
         'kernel',
@@ -1184,27 +1179,12 @@ class Dense(Layer):
     self.built = True

   def call(self, inputs):
-    base_layer_utils.no_ragged_support(inputs, self.name)
-    rank = inputs.shape.rank
-    if rank is not None and rank > 2:
-      # Broadcasting is required for the inputs.
-      outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
-      # Reshape the output back to the original ndim of the input.
-      if not context.executing_eagerly():
-        shape = inputs.shape.as_list()
-        output_shape = shape[:-1] + [self.units]
-        outputs.set_shape(output_shape)
-    else:
-      inputs = math_ops.cast(inputs, self._compute_dtype)
-      if K.is_sparse(inputs):
-        outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel)
-      else:
-        outputs = gen_math_ops.mat_mul(inputs, self.kernel)
-    if self.use_bias:
-      outputs = nn.bias_add(outputs, self.bias)
-    if self.activation is not None:
-      return self.activation(outputs)  # pylint: disable=not-callable
-    return outputs
+    return core_ops.dense(
+        inputs,
+        self.kernel,
+        self.bias,
+        self.activation,
+        dtype=self._compute_dtype_object)

   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
@@ -1216,21 +1196,30 @@ class Dense(Layer):
     return input_shape[:-1].concatenate(self.units)

   def get_config(self):
-    config = {
-        'units': self.units,
-        'activation': activations.serialize(self.activation),
-        'use_bias': self.use_bias,
-        'kernel_initializer': initializers.serialize(self.kernel_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
-        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+    config = super(Dense, self).get_config()
+    config.update({
+        'units':
+            self.units,
+        'activation':
+            activations.serialize(self.activation),
+        'use_bias':
+            self.use_bias,
+        'kernel_initializer':
+            initializers.serialize(self.kernel_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
+        'kernel_regularizer':
+            regularizers.serialize(self.kernel_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
         'activity_regularizer':
             regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint': constraints.serialize(self.kernel_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint)
-    }
-    base_config = super(Dense, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+        'kernel_constraint':
+            constraints.serialize(self.kernel_constraint),
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint)
+    })
+    return config


 @keras_export('keras.layers.ActivityRegularization')
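The rewritten `get_config` starts from the base `Layer` config and updates it in place instead of merging two dicts at the end; serialization round-trips the same way (a sketch using the public API, not part of the diff):

    import tensorflow as tf

    layer = tf.keras.layers.Dense(8, activation='relu', activity_regularizer='l2')
    config = layer.get_config()  # also carries base keys such as 'name' and 'dtype'
    clone = tf.keras.layers.Dense.from_config(config)
    print(clone.units, clone.activation.__name__)  # 8 relu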
@@ -0,0 +1,24 @@
+# Description:
+#   Contains stateless ops for Keras layers.
+
+package(
+    default_visibility = [
+        "//tensorflow/python/keras/layers:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "core",
+    srcs = ["core.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:sparse_ops",
+        "//tensorflow/python:standard_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/eager:context",
+    ],
+)
@@ -0,0 +1,19 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless ops for Keras layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
@@ -0,0 +1,69 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless ops for core Keras layers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import standard_ops
+
+
+# TODO(b/157913406): Expose this publicly.
+def dense(inputs, kernel, bias=None, activation=None, dtype=None):
+  """Densely connected NN layer op.
+
+  Arguments:
+    inputs: `tf.Tensor` or `tf.SparseTensor`. Inputs to operation.
+    kernel: `tf.Variable`. Matrix kernel.
+    bias: (Optional) `tf.Variable`. Bias to add to outputs.
+    activation: (Optional) 1-argument callable. Activation function to apply to
+      outputs.
+    dtype: (Optional) `tf.DType`. Dtype to cast `inputs` to.
+
+  Returns:
+    `tf.Tensor`. Output of dense connection.
+  """
+  if dtype:
+    if inputs.dtype.base_dtype != dtype.base_dtype:
+      inputs = math_ops.cast(inputs, dtype=dtype)
+
+  rank = inputs.shape.rank
+  if rank == 2 or rank is None:
+    if isinstance(inputs, sparse_tensor.SparseTensor):
+      outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, kernel)
+    else:
+      outputs = gen_math_ops.mat_mul(inputs, kernel)
+  # Broadcast kernel to inputs.
+  else:
+    outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])
+    # Reshape the output back to the original ndim of the input.
+    if not context.executing_eagerly():
+      shape = inputs.shape.as_list()
+      output_shape = shape[:-1] + [kernel.shape[-1]]
+      outputs.set_shape(output_shape)
+
+  if bias is not None:
+    outputs = nn_ops.bias_add(outputs, bias)
+
+  if activation is not None:
+    outputs = activation(outputs)
+
+  return outputs
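For reference, the new stateless op can be called directly; a minimal sketch with illustrative variable names (the op itself is exactly the one added above):

    import tensorflow as tf
    from tensorflow.python.keras.layers.ops import core as core_ops

    x = tf.constant([[1., 2.]])            # shape [1, 2]
    kernel = tf.Variable(tf.ones([2, 3]))  # 2 input features -> 3 units
    bias = tf.Variable(tf.zeros([3]))

    y = core_ops.dense(x, kernel, bias, activation=tf.nn.relu)
    print(y.shape)  # (1, 3)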