Reduce Dense.__call__ overhead by ~15%

- Creates a layers/ops directory to contain functional implementations of Keras
  layer ops (this should make these ops easier to compile with XLA, and could
  also be useful for exposing stateless versions of the layer ops)
- Moves the Dense op to this directory and implements a functional version of
  Dense.call with reduced Python overhead (see the sketch after this list)
- Uses this op in Dense layer
- Cleans up input_dim and activity_regularizer handling in Dense and Layer
- Adds a microbenchmark for Dense
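
A minimal sketch of the stateless usage this enables (illustrative only, not
part of the diff; variable creation is normally handled by the layer):

  import tensorflow as tf
  from tensorflow.python.keras.layers.ops import core as core_ops

  layer = tf.keras.layers.Dense(2, activation='relu')
  x = tf.ones([8, 4])
  y_layer = layer(x)  # Stateful path: builds and owns kernel/bias.

  # Equivalent call through the stateless functional op:
  y_op = core_ops.dense(x, layer.kernel, layer.bias, layer.activation)
  assert y_layer.shape == y_op.shape == (8, 2)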

Also fixes two small issues:
- compute_dtype_object was not always correctly set in Layer
- activity_regularizer can now be passed as a str to any Layer class
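
Example of the second fix (a sketch mirroring the new unit test in the diff
below):

  import tensorflow as tf

  class MyLayer(tf.keras.layers.Layer):
    pass

  # The string is now resolved via `regularizers.get` instead of being
  # stored raw, so this yields an L2 regularizer instance:
  layer = MyLayer(activity_regularizer='l2')
  print(layer.activity_regularizer)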

PiperOrigin-RevId: 314256375
Change-Id: I769cef6d67aa117f6cb75dc34c0594e748284af2
Thomas O'Malley 2020-06-01 20:19:00 -07:00 committed by TensorFlower Gardener
parent df7fd4acda
commit 70387ab55b
9 changed files with 189 additions and 66 deletions

--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py

@@ -53,6 +53,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.layers import core as core_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -1404,7 +1405,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
     self._run(fn, 10000)

   # TODO(b/157587712): Move to keras when benchmarks are setup.
-  def benchmark_tf_keras_layer_call(self):
+  def benchmark_tf_keras_layer_call_overhead(self):

     class OnlyOverheadLayer(base_layer.Layer):
@@ -1419,6 +1420,18 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
     self._run(fn, 10000)

+  # TODO(b/157587712): Move to keras when benchmarks are setup.
+  def benchmark_tf_keras_dense_overhead(self):
+
+    layer = core_layers.Dense(1)
+    x = ops.convert_to_tensor([[1.]])
+    layer(x)  # Warmup call to `build` layer.
+
+    def fn():
+      layer(x)
+
+    self._run(fn, 10000)
+

 if __name__ == "__main__":
   test.main()

--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py

@@ -293,12 +293,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape',
-        'batch_input_shape',
-        'batch_size',
-        'weights',
-        'activity_regularizer',
-        'autocast'
+        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
+        'weights', 'activity_regularizer', 'autocast'
     }
     # Validate optional keyword arguments.
     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
@@ -323,7 +319,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     self.supports_masking = False

     self._init_set_name(name)
-    self._activity_regularizer = kwargs.pop('activity_regularizer', None)
+    self._activity_regularizer = regularizers.get(
+        kwargs.pop('activity_regularizer', None))
     self._maybe_create_attribute('_trainable_weights', [])
     self._maybe_create_attribute('_non_trainable_weights', [])
     self._updates = []
@@ -370,6 +367,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     self._dynamic = dynamic

     # Manage input shape information if passed.
+    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
+      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
+      kwargs['input_shape'] = (kwargs['input_dim'],)
     if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
       # In this case we will later create an input layer
       # to insert before the current layer
@@ -530,7 +530,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     dtype = dtypes.as_dtype(dtype)
     if self._dtype_policy.variable_dtype is None:
       # The policy is "_infer", so we infer the policy from the variable dtype.
-      self._dtype_policy = policy.Policy(dtype.base_dtype.name)
+      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
     initializer = initializers.get(initializer)
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
@@ -2202,7 +2202,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
   @_dtype.setter
   def _dtype(self, value):
     value = dtypes.as_dtype(value).name
-    self._dtype_policy = policy.Policy(value)
+    self._set_dtype_policy(policy.Policy(value))

   def _name_scope(self):
     if not tf2.enabled():
@@ -2436,7 +2436,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       except AttributeError:
         pass
       else:
-        self._dtype_policy = policy.Policy(dtype)
+        self._set_dtype_policy(policy.Policy(dtype))
     input_shapes = None
     # Converts Tensors / CompositeTensors to TensorShapes.
     if all(hasattr(x, 'shape') for x in input_list):

--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py

@@ -616,6 +616,14 @@ class BaseLayerTest(keras_parameterized.TestCase):
     self.assertTrue(layer.built)
     self.assertEqual([None, 3], layer._build_input_shape.as_list())

+  def test_activity_regularizer_string(self):
+
+    class MyLayer(base_layer.Layer):
+      pass
+
+    layer = MyLayer(activity_regularizer='l2')
+    self.assertIsInstance(layer.activity_regularizer, regularizers.L2)
+

 class SymbolicSupportTest(keras_parameterized.TestCase):

--- a/tensorflow/python/keras/engine/base_layer_v1.py
+++ b/tensorflow/python/keras/engine/base_layer_v1.py

@@ -158,12 +158,8 @@ class Layer(base_layer.Layer):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape',
-        'batch_input_shape',
-        'batch_size',
-        'weights',
-        'activity_regularizer',
-        'autocast'
+        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
+        'weights', 'activity_regularizer', 'autocast'
     }
     # Validate optional keyword arguments.
     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
@@ -184,7 +180,8 @@ class Layer(base_layer.Layer):
     self.supports_masking = False

     self._init_set_name(name)
-    self._activity_regularizer = kwargs.pop('activity_regularizer', None)
+    self._activity_regularizer = regularizers.get(
+        kwargs.pop('activity_regularizer', None))
     self._maybe_create_attribute('_trainable_weights', [])
     self._maybe_create_attribute('_non_trainable_weights', [])
     self._updates = []
@@ -229,6 +226,9 @@ class Layer(base_layer.Layer):
     self._dynamic = dynamic

     # Manage input shape information if passed.
+    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
+      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
+      kwargs['input_shape'] = (kwargs['input_dim'],)
     if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
       # In this case we will later create an input layer
       # to insert before the current layer
@@ -378,7 +378,7 @@ class Layer(base_layer.Layer):
     dtype = dtypes.as_dtype(dtype)
     if self._dtype_policy.variable_dtype is None:
       # The policy is "_infer", so we infer the policy from the variable dtype.
-      self._dtype_policy = policy.Policy(dtype.base_dtype.name)
+      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
     initializer = initializers.get(initializer)
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
@@ -1835,7 +1835,7 @@ class Layer(base_layer.Layer):
   @_dtype.setter
   def _dtype(self, value):
     value = dtypes.as_dtype(value).name
-    self._dtype_policy = policy.Policy(value)
+    self._set_dtype_policy(policy.Policy(value))

   def _name_scope(self):
     return self.name
@@ -2068,7 +2068,7 @@ class Layer(base_layer.Layer):
       except AttributeError:
         pass
       else:
-        self._dtype_policy = policy.Policy(dtype)
+        self._set_dtype_policy(policy.Policy(dtype))
     input_shapes = None
     if all(hasattr(x, 'shape') for x in input_list):
       input_shapes = nest.map_structure(lambda x: x.shape, inputs)

--- a/tensorflow/python/keras/layers/BUILD
+++ b/tensorflow/python/keras/layers/BUILD

@@ -141,6 +141,7 @@ py_library(
         "//tensorflow/python/keras:initializers",
         "//tensorflow/python/keras:regularizers",
         "//tensorflow/python/keras/engine:input_spec",
+        "//tensorflow/python/keras/layers/ops:core",
         "//tensorflow/python/keras/utils:engine_utils",
         "//tensorflow/python/keras/utils:generic_utils",
         "//tensorflow/python/keras/utils:tf_utils",

--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py

@@ -37,18 +37,15 @@ from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.input_spec import InputSpec
+from tensorflow.python.keras.layers.ops import core as core_ops
 from tensorflow.python.keras.utils import conv_utils
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import standard_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.tracking import base as trackable
@@ -1132,11 +1129,8 @@ class Dense(Layer):
                kernel_constraint=None,
                bias_constraint=None,
                **kwargs):
-    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
-      kwargs['input_shape'] = (kwargs.pop('input_dim'),)
-
     super(Dense, self).__init__(
-        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+        activity_regularizer=activity_regularizer, **kwargs)

     self.units = int(units) if not isinstance(units, int) else units
     self.activation = activations.get(activation)
@@ -1148,19 +1142,20 @@ class Dense(Layer):
     self.kernel_constraint = constraints.get(kernel_constraint)
     self.bias_constraint = constraints.get(bias_constraint)

-    self.supports_masking = True
     self.input_spec = InputSpec(min_ndim=2)
+    self.supports_masking = True

   def build(self, input_shape):
     dtype = dtypes.as_dtype(self.dtype or K.floatx())
     if not (dtype.is_floating or dtype.is_complex):
       raise TypeError('Unable to build `Dense` layer with non-floating point '
                       'dtype %s' % (dtype,))
+
     input_shape = tensor_shape.TensorShape(input_shape)
-    if tensor_shape.dimension_value(input_shape[-1]) is None:
+    last_dim = tensor_shape.dimension_value(input_shape[-1])
+    if last_dim is None:
       raise ValueError('The last dimension of the inputs to `Dense` '
                        'should be defined. Found `None`.')
-    last_dim = tensor_shape.dimension_value(input_shape[-1])
     self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
     self.kernel = self.add_weight(
         'kernel',
@@ -1184,27 +1179,12 @@ class Dense(Layer):
     self.built = True

   def call(self, inputs):
-    base_layer_utils.no_ragged_support(inputs, self.name)
-    rank = inputs.shape.rank
-    if rank is not None and rank > 2:
-      # Broadcasting is required for the inputs.
-      outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
-      # Reshape the output back to the original ndim of the input.
-      if not context.executing_eagerly():
-        shape = inputs.shape.as_list()
-        output_shape = shape[:-1] + [self.units]
-        outputs.set_shape(output_shape)
-    else:
-      inputs = math_ops.cast(inputs, self._compute_dtype)
-      if K.is_sparse(inputs):
-        outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel)
-      else:
-        outputs = gen_math_ops.mat_mul(inputs, self.kernel)
-      if self.use_bias:
-        outputs = nn.bias_add(outputs, self.bias)
-    if self.activation is not None:
-      return self.activation(outputs)  # pylint: disable=not-callable
-    return outputs
+    return core_ops.dense(
+        inputs,
+        self.kernel,
+        self.bias,
+        self.activation,
+        dtype=self._compute_dtype_object)

   def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
@@ -1216,21 +1196,30 @@ class Dense(Layer):
     return input_shape[:-1].concatenate(self.units)

   def get_config(self):
-    config = {
-        'units': self.units,
-        'activation': activations.serialize(self.activation),
-        'use_bias': self.use_bias,
-        'kernel_initializer': initializers.serialize(self.kernel_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
-        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+    config = super(Dense, self).get_config()
+    config.update({
+        'units':
+            self.units,
+        'activation':
+            activations.serialize(self.activation),
+        'use_bias':
+            self.use_bias,
+        'kernel_initializer':
+            initializers.serialize(self.kernel_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
+        'kernel_regularizer':
+            regularizers.serialize(self.kernel_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
         'activity_regularizer':
             regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint': constraints.serialize(self.kernel_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint)
-    }
-    base_config = super(Dense, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
+        'kernel_constraint':
+            constraints.serialize(self.kernel_constraint),
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint)
+    })
+    return config


 @keras_export('keras.layers.ActivityRegularization')

--- /dev/null
+++ b/tensorflow/python/keras/layers/ops/BUILD

@@ -0,0 +1,24 @@
+# Description:
+#   Contains stateless ops for Keras layers.
+
+package(
+    default_visibility = [
+        "//tensorflow/python/keras/layers:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "core",
+    srcs = ["core.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:sparse_ops",
+        "//tensorflow/python:standard_ops",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/eager:context",
+    ],
+)

--- /dev/null
+++ b/tensorflow/python/keras/layers/ops/__init__.py

@@ -0,0 +1,19 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless ops for Keras layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function

--- /dev/null
+++ b/tensorflow/python/keras/layers/ops/core.py

@@ -0,0 +1,69 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless ops for core Keras layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import standard_ops
+
+
+# TODO(b/157913406): Expose this publicly.
+def dense(inputs, kernel, bias=None, activation=None, dtype=None):
+  """Densely connected NN layer op.
+
+  Arguments:
+    inputs: `tf.Tensor` or `tf.SparseTensor`. Inputs to operation.
+    kernel: `tf.Variable`. Matrix kernel.
+    bias: (Optional) `tf.Variable`. Bias to add to outputs.
+    activation: (Optional) 1-argument callable. Activation function to apply to
+      outputs.
+    dtype: (Optional) `tf.DType`. Dtype to cast `inputs` to.
+
+  Returns:
+    `tf.Tensor`. Output of dense connection.
+  """
+  if dtype:
+    if inputs.dtype.base_dtype != dtype.base_dtype:
+      inputs = math_ops.cast(inputs, dtype=dtype)
+
+  rank = inputs.shape.rank
+  if rank == 2 or rank is None:
+    if isinstance(inputs, sparse_tensor.SparseTensor):
+      outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, kernel)
+    else:
+      outputs = gen_math_ops.mat_mul(inputs, kernel)
+  # Broadcast kernel to inputs.
+  else:
+    outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])
+    # Reshape the output back to the original ndim of the input.
+    if not context.executing_eagerly():
+      shape = inputs.shape.as_list()
+      output_shape = shape[:-1] + [kernel.shape[-1]]
+      outputs.set_shape(output_shape)
+
+  if bias is not None:
+    outputs = nn_ops.bias_add(outputs, bias)
+  if activation is not None:
+    outputs = activation(outputs)
+  return outputs