diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD
index de295f1466a..de071893704 100644
--- a/tensorflow/python/keras/layers/BUILD
+++ b/tensorflow/python/keras/layers/BUILD
@@ -219,6 +219,7 @@ py_library(
     srcs = ["multi_head_attention.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python:platform",
         "//tensorflow/python:special_math_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:util",
diff --git a/tensorflow/python/keras/layers/multi_head_attention.py b/tensorflow/python/keras/layers/multi_head_attention.py
index 164a5f0b9a7..3f7ff856bc0 100644
--- a/tensorflow/python/keras/layers/multi_head_attention.py
+++ b/tensorflow/python/keras/layers/multi_head_attention.py
@@ -37,6 +37,7 @@ from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import special_math_ops
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -245,6 +246,7 @@ class MultiHeadAttention(Layer):
     else:
       self._attention_axes = attention_axes
     self._built_from_signature = False
+    self._query_shape, self._key_shape, self._value_shape = None, None, None
 
   def get_config(self):
     config = {
@@ -275,11 +277,32 @@ class MultiHeadAttention(Layer):
         "kernel_constraint":
            constraints.serialize(self._kernel_constraint),
         "bias_constraint":
-           constraints.serialize(self._bias_constraint)
+           constraints.serialize(self._bias_constraint),
+        "query_shape": self._query_shape,
+        "key_shape": self._key_shape,
+        "value_shape": self._value_shape,
     }
     base_config = super(MultiHeadAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
+  @classmethod
+  def from_config(cls, config):
+    # If the layer has a different build() function from the Keras default,
+    # we need to trigger the customized build to create weights.
+    query_shape = config.pop("query_shape")
+    key_shape = config.pop("key_shape")
+    value_shape = config.pop("value_shape")
+    layer = cls(**config)
+    if None in [query_shape, key_shape, value_shape]:
+      logging.warning(
+          "One of the input shapes is missing. It should have been "
+          "memorized when the layer was serialized. "
+          "%s is created without weights.",
+          str(cls))
+    else:
+      layer._build_from_signature(query_shape, value_shape, key_shape)  # pylint: disable=protected-access
+    return layer
+
   def _build_from_signature(self, query, value, key=None):
     """Builds layers and variables.
 
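A minimal usage sketch, not part of the diff, of the two from_config() paths added in the hunk above. It uses the public tf.keras export of the layer and illustrative sizes; it assumes this change is applied:

import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)

# Never called: the three memorized shapes serialize as None, so
# from_config() logs the warning above and skips _build_from_signature().
print(layer.get_config()["query_shape"])   # None
restored = tf.keras.layers.MultiHeadAttention.from_config(layer.get_config())

# Called once: the input shapes are memorized and travel with the config,
# so from_config() can recreate the projection sub-layers up front.
x = tf.random.uniform((2, 8, 16))
_ = layer(x, x)
print(layer.get_config()["query_shape"])   # (2, 8, 16)
restored = tf.keras.layers.MultiHeadAttention.from_config(layer.get_config())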
@@ -292,19 +315,19 @@ class MultiHeadAttention(Layer):
     """
     self._built_from_signature = True
     if hasattr(query, "shape"):
-      query_shape = tensor_shape.TensorShape(query.shape)
+      self._query_shape = tensor_shape.TensorShape(query.shape)
     else:
-      query_shape = query
+      self._query_shape = tensor_shape.TensorShape(query)
     if hasattr(value, "shape"):
-      value_shape = tensor_shape.TensorShape(value.shape)
+      self._value_shape = tensor_shape.TensorShape(value.shape)
     else:
-      value_shape = value
+      self._value_shape = tensor_shape.TensorShape(value)
     if key is None:
-      key_shape = value_shape
+      self._key_shape = self._value_shape
     elif hasattr(key, "shape"):
-      key_shape = tensor_shape.TensorShape(key.shape)
+      self._key_shape = tensor_shape.TensorShape(key.shape)
     else:
-      key_shape = key
+      self._key_shape = tensor_shape.TensorShape(key)
 
     common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
@@ -318,7 +341,7 @@ class MultiHeadAttention(Layer):
     # to avoid creating symbolic Tensors that will later pollute any eager
     # operations.
     with tf_utils.maybe_init_scope(self):
-      free_dims = query_shape.rank - 1
+      free_dims = self._query_shape.rank - 1
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           free_dims, bound_dims=1, output_dims=2)
       self._query_dense = einsum_dense.EinsumDense(
@@ -329,7 +352,7 @@ class MultiHeadAttention(Layer):
           name="query",
           **common_kwargs)
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
-          key_shape.rank - 1, bound_dims=1, output_dims=2)
+          self._key_shape.rank - 1, bound_dims=1, output_dims=2)
       self._key_dense = einsum_dense.EinsumDense(
           einsum_equation,
           output_shape=_get_output_shape(output_rank - 1,
@@ -338,7 +361,7 @@ class MultiHeadAttention(Layer):
           name="key",
           **common_kwargs)
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
-          value_shape.rank - 1, bound_dims=1, output_dims=2)
+          self._value_shape.rank - 1, bound_dims=1, output_dims=2)
       self._value_dense = einsum_dense.EinsumDense(
           einsum_equation,
           output_shape=_get_output_shape(output_rank - 1,
@@ -357,7 +380,7 @@ class MultiHeadAttention(Layer):
         else:
           output_shape = self._output_shape
       else:
-        output_shape = [query_shape[-1]]
+        output_shape = [self._query_shape[-1]]
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           free_dims, bound_dims=2, output_dims=len(output_shape))
       self._output_dense = einsum_dense.EinsumDense(
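For context on the hunks above: once the query shape is memorized, its rank drives _build_proj_equation and the result is handed to EinsumDense. A standalone sketch of the query projection for a rank-3 [batch, seq, dim] input follows; it is not part of the diff, and the "abc,cde->abde" equation, bias axes, and sizes are written out by hand here for illustration rather than generated:

import tensorflow as tf
from tensorflow.python.keras.layers import einsum_dense

num_heads, key_dim = 3, 4
query_shape = tf.TensorShape([None, 40, 80])   # what self._query_shape memorizes

# free_dims = query_shape.rank - 1 == 2: the batch and sequence axes pass
# through untouched, while the last (feature) axis is projected onto
# (num_heads, key_dim), which is what "abc,cde->abde" spells out.
query_dense = einsum_dense.EinsumDense(
    "abc,cde->abde",
    output_shape=[None, num_heads, key_dim],
    bias_axes="de",
    name="query")

q = tf.random.uniform((2, 40, 80))
print(query_dense(q).shape)   # (2, 40, 3, 4): one query per head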
"/tmp_model" + keras.models.save_model(model, model_path, save_format="tf") + reloaded_model = keras.models.load_model(model_path) + self.assertEqual( + len(model.trainable_variables), len(reloaded_model.trainable_variables)) + for src_v, loaded_v in zip(model.trainable_variables, + reloaded_model.trainable_variables): + self.assertAllEqual(src_v, loaded_v) + + @parameterized.parameters("h5", "tf") + def test_keras_saving_functional(self, save_format): + model = TestModel() + query = keras.Input(shape=(40, 80)) + output = multi_head_attention.MultiHeadAttention( + num_heads=3, + key_dim=4, + value_dim=4, + use_bias=True, + dropout=0.0)(query, query) + model = keras.Model(inputs=query, outputs=output) + model_path = self.get_temp_dir() + "/tmp_model" + keras.models.save_model(model, model_path, save_format=save_format) + reloaded_model = keras.models.load_model(model_path) + self.assertEqual( + len(model.trainable_variables), len(reloaded_model.trainable_variables)) + for src_v, loaded_v in zip(model.trainable_variables, + reloaded_model.trainable_variables): + self.assertAllEqual(src_v, loaded_v) + + def test_create_without_build(self): + not_intialized_layer = multi_head_attention.MultiHeadAttention( + num_heads=3, key_dim=4, value_dim=4) + multi_head_attention.MultiHeadAttention.from_config( + not_intialized_layer.get_config()) + + if __name__ == "__main__": test.main()