Fix Keras saving bug in MultiHeadAttention.

Remember the input shapes when _build_from_signature is called and trigger the build in from_config.

PiperOrigin-RevId: 350795939
Change-Id: I8e04f510e41cb9c63854ac79e1778a12cbe53eea
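To make the mechanism concrete before the diff: the pattern is to serialize the shapes the layer was built with in get_config() and to rebuild the weights from those shapes in from_config(). A minimal, hypothetical sketch of that pattern (ShapeAwareLayer and its fields are illustrative placeholders, not the actual MultiHeadAttention code):

import tensorflow as tf


class ShapeAwareLayer(tf.keras.layers.Layer):
  """Illustrative layer that memorizes its build shape for serialization."""

  def __init__(self, units=8, **kwargs):
    super(ShapeAwareLayer, self).__init__(**kwargs)
    self.units = units
    self._build_input_shape = None  # remembered when build() runs

  def build(self, input_shape):
    # Memorize the shape so it can be written into the config.
    self._build_input_shape = tf.TensorShape(input_shape)
    self.kernel = self.add_weight(
        name="kernel", shape=[int(input_shape[-1]), self.units])
    super(ShapeAwareLayer, self).build(input_shape)

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)

  def get_config(self):
    config = super(ShapeAwareLayer, self).get_config()
    config.update({
        "units": self.units,
        # None if the layer was never built.
        "build_input_shape": (self._build_input_shape.as_list()
                              if self._build_input_shape is not None else None),
    })
    return config

  @classmethod
  def from_config(cls, config):
    build_input_shape = config.pop("build_input_shape")
    layer = cls(**config)
    if build_input_shape is not None:
      # Trigger the custom build so the restored weights have variables to land in.
      layer.build(build_input_shape)
    return layer

MultiHeadAttention needs this treatment because its weights are created in _build_from_signature() rather than in the default build() path, so without memorized shapes a deserialized layer would come back unbuilt.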
parent 40eec99c04
commit 7595aeda0b

tensorflow/python/keras/layers/BUILD
@@ -219,6 +219,7 @@ py_library(
     srcs = ["multi_head_attention.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python:platform",
         "//tensorflow/python:special_math_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:util",

tensorflow/python/keras/layers/multi_head_attention.py

@@ -37,6 +37,7 @@ from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import special_math_ops
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import keras_export


@@ -245,6 +246,7 @@ class MultiHeadAttention(Layer):
     else:
       self._attention_axes = attention_axes
     self._built_from_signature = False
+    self._query_shape, self._key_shape, self._value_shape = None, None, None

   def get_config(self):
     config = {
@@ -275,11 +277,32 @@
         "kernel_constraint":
             constraints.serialize(self._kernel_constraint),
         "bias_constraint":
-            constraints.serialize(self._bias_constraint)
+            constraints.serialize(self._bias_constraint),
+        "query_shape": self._query_shape,
+        "key_shape": self._key_shape,
+        "value_shape": self._value_shape,
     }
     base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

+  @classmethod
+  def from_config(cls, config):
+    # If the layer has a different build() function from the Keras default,
+    # we need to trigger the customized build to create weights.
+    query_shape = config.pop("query_shape")
+    key_shape = config.pop("key_shape")
+    value_shape = config.pop("value_shape")
+    layer = cls(**config)
+    if None in [query_shape, key_shape, value_shape]:
+      logging.warning(
+          "One of the input shapes is missing. They should have been "
+          "memorized when the layer was serialized. "
+          "%s is created without weights.",
+          str(cls))
+    else:
+      layer._build_from_signature(query_shape, value_shape, key_shape)  # pylint: disable=protected-access
+    return layer
+
   def _build_from_signature(self, query, value, key=None):
     """Builds layers and variables.

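To see the effect of the get_config()/from_config() change above, here is a short sketch, assuming a TensorFlow build that already contains this fix and using the layer's public alias tf.keras.layers.MultiHeadAttention (the dimensions below are arbitrary):

import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
# Calling the layer triggers _build_from_signature(), which now memorizes
# the query/key/value shapes on the layer instance.
x = tf.zeros([1, 5, 16])
layer(x, x)

config = layer.get_config()
print(config["query_shape"], config["key_shape"], config["value_shape"])

# from_config() pops those shapes and rebuilds the weights immediately,
# so a deserialized layer has the same number of variables as the original.
clone = tf.keras.layers.MultiHeadAttention.from_config(config)
assert len(clone.weights) == len(layer.weights)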
@@ -292,19 +315,19 @@
     """
     self._built_from_signature = True
     if hasattr(query, "shape"):
-      query_shape = tensor_shape.TensorShape(query.shape)
+      self._query_shape = tensor_shape.TensorShape(query.shape)
     else:
-      query_shape = query
+      self._query_shape = tensor_shape.TensorShape(query)
     if hasattr(value, "shape"):
-      value_shape = tensor_shape.TensorShape(value.shape)
+      self._value_shape = tensor_shape.TensorShape(value.shape)
     else:
-      value_shape = value
+      self._value_shape = tensor_shape.TensorShape(value)
     if key is None:
-      key_shape = value_shape
+      self._key_shape = self._value_shape
     elif hasattr(key, "shape"):
-      key_shape = tensor_shape.TensorShape(key.shape)
+      self._key_shape = tensor_shape.TensorShape(key.shape)
     else:
-      key_shape = key
+      self._key_shape = tensor_shape.TensorShape(key)

     common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
@@ -318,7 +341,7 @@
     # to avoid creating symbolic Tensors that will later pollute any eager
     # operations.
     with tf_utils.maybe_init_scope(self):
-      free_dims = query_shape.rank - 1
+      free_dims = self._query_shape.rank - 1
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           free_dims, bound_dims=1, output_dims=2)
       self._query_dense = einsum_dense.EinsumDense(
@@ -329,7 +352,7 @@
           name="query",
           **common_kwargs)
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
-          key_shape.rank - 1, bound_dims=1, output_dims=2)
+          self._key_shape.rank - 1, bound_dims=1, output_dims=2)
       self._key_dense = einsum_dense.EinsumDense(
           einsum_equation,
           output_shape=_get_output_shape(output_rank - 1,
@@ -338,7 +361,7 @@
           name="key",
           **common_kwargs)
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
-          value_shape.rank - 1, bound_dims=1, output_dims=2)
+          self._value_shape.rank - 1, bound_dims=1, output_dims=2)
       self._value_dense = einsum_dense.EinsumDense(
           einsum_equation,
           output_shape=_get_output_shape(output_rank - 1,
@@ -357,7 +380,7 @@
         else:
           output_shape = self._output_shape
       else:
-        output_shape = [query_shape[-1]]
+        output_shape = [self._query_shape[-1]]
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           free_dims, bound_dims=2, output_dims=len(output_shape))
       self._output_dense = einsum_dense.EinsumDense(

tensorflow/python/keras/layers/multi_head_attention_test.py

@@ -268,5 +268,71 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
     self.assertEqual(output.shape.as_list(), [None, 40, 80])


+class TestModel(keras.Model):
+
+  def __init__(self):
+    super(TestModel, self).__init__()
+    self.attention = multi_head_attention.MultiHeadAttention(
+        num_heads=3,
+        key_dim=4,
+        value_dim=4,
+        use_bias=True,
+        dropout=0.0,
+        output_shape=[12])
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(**config)
+
+  def get_config(self):
+    return {}
+
+  def call(self, x, training=False):
+    return self.attention(x, x, training=training)
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class KerasModelSavingTest(keras_parameterized.TestCase):
+
+  def test_keras_saving_subclass(self):
+    model = TestModel()
+    query = keras.Input(shape=(40, 80))
+    _ = model(query)
+    model_path = self.get_temp_dir() + "/tmp_model"
+    keras.models.save_model(model, model_path, save_format="tf")
+    reloaded_model = keras.models.load_model(model_path)
+    self.assertEqual(
+        len(model.trainable_variables), len(reloaded_model.trainable_variables))
+    for src_v, loaded_v in zip(model.trainable_variables,
+                               reloaded_model.trainable_variables):
+      self.assertAllEqual(src_v, loaded_v)
+
+  @parameterized.parameters("h5", "tf")
+  def test_keras_saving_functional(self, save_format):
+    model = TestModel()
+    query = keras.Input(shape=(40, 80))
+    output = multi_head_attention.MultiHeadAttention(
+        num_heads=3,
+        key_dim=4,
+        value_dim=4,
+        use_bias=True,
+        dropout=0.0)(query, query)
+    model = keras.Model(inputs=query, outputs=output)
+    model_path = self.get_temp_dir() + "/tmp_model"
+    keras.models.save_model(model, model_path, save_format=save_format)
+    reloaded_model = keras.models.load_model(model_path)
+    self.assertEqual(
+        len(model.trainable_variables), len(reloaded_model.trainable_variables))
+    for src_v, loaded_v in zip(model.trainable_variables,
+                               reloaded_model.trainable_variables):
+      self.assertAllEqual(src_v, loaded_v)
+
+  def test_create_without_build(self):
+    not_initialized_layer = multi_head_attention.MultiHeadAttention(
+        num_heads=3, key_dim=4, value_dim=4)
+    multi_head_attention.MultiHeadAttention.from_config(
+        not_initialized_layer.get_config())
+
+
 if __name__ == "__main__":
   test.main()
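For an end-to-end view of the behavior the new tests lock down, here is a usage sketch with the public API; the save path is arbitrary, and save_format="tf" is the SavedModel format exercised above:

import tensorflow as tf

query = tf.keras.Input(shape=(40, 80))
output = tf.keras.layers.MultiHeadAttention(num_heads=3, key_dim=4)(query, query)
model = tf.keras.Model(inputs=query, outputs=output)

# With the memorized shapes, load_model can rebuild the attention weights
# and restore their saved values instead of returning an unbuilt layer.
model.save("/tmp/mha_model", save_format="tf")
reloaded = tf.keras.models.load_model("/tmp/mha_model")
assert len(model.trainable_variables) == len(reloaded.trainable_variables)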