From 256332096c08e67ecf080cae457b8d5287e241cc Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 25 May 2020 17:04:14 -0700 Subject: [PATCH] Make RandomFourierFeatures state saveable. PiperOrigin-RevId: 313112328 Change-Id: I21c8881b84d8d40e90e3dc82bb38154bc928b5f4 --- tensorflow/python/keras/layers/kernelized.py | 14 +++++------ .../python/keras/layers/kernelized_test.py | 23 ++++++++++++++++++- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/layers/kernelized.py b/tensorflow/python/keras/layers/kernelized.py index ce53334ebc7..5f401899bec 100644 --- a/tensorflow/python/keras/layers/kernelized.py +++ b/tensorflow/python/keras/layers/kernelized.py @@ -191,15 +191,15 @@ class RandomFourierFeatures(base_layer.Layer): kernel_initializer = _get_random_features_initializer( self.kernel_initializer, shape=(input_dim, self.output_dim)) - unscaled_kernel = self.add_weight( - name='unscaled_random_features', + self.unscaled_kernel = self.add_weight( + name='unscaled_kernel', shape=(input_dim, self.output_dim), dtype=dtypes.float32, initializer=kernel_initializer, trainable=False) self.bias = self.add_weight( - name='random_features_bias', + name='bias', shape=(self.output_dim,), dtype=dtypes.float32, initializer=init_ops.random_uniform_initializer( @@ -208,20 +208,20 @@ class RandomFourierFeatures(base_layer.Layer): if self.scale is None: self.scale = _get_default_scale(self.kernel_initializer, input_dim) - scale = self.add_weight( - name='random_features_scale', + self.kernel_scale = self.add_weight( + name='kernel_scale', shape=(1,), dtype=dtypes.float32, initializer=init_ops.constant_initializer(self.scale), trainable=True, constraint='NonNeg') - self.kernel = (1.0 / scale) * unscaled_kernel super(RandomFourierFeatures, self).build(input_shape) def call(self, inputs): inputs = ops.convert_to_tensor_v2(inputs, dtype=self.dtype) inputs = gen_math_ops.cast(inputs, dtypes.float32) - outputs = gen_math_ops.mat_mul(inputs, self.kernel) + kernel = (1.0 / self.kernel_scale) * self.unscaled_kernel + outputs = gen_math_ops.mat_mul(inputs, kernel) outputs = nn.bias_add(outputs, self.bias) return gen_math_ops.cos(outputs) diff --git a/tensorflow/python/keras/layers/kernelized_test.py b/tensorflow/python/keras/layers/kernelized_test.py index edb58f77868..a6a9d88423f 100644 --- a/tensorflow/python/keras/layers/kernelized_test.py +++ b/tensorflow/python/keras/layers/kernelized_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import functools import math +import os +import shutil from absl.testing import parameterized import numpy as np @@ -35,7 +37,10 @@ from tensorflow.python.keras import backend as keras_backend from tensorflow.python.keras import combinations from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.engine import input_layer +from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import kernelized as kernel_layers +from tensorflow.python.keras.saving import save from tensorflow.python.keras.utils import kernelized_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -65,6 +70,22 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase): else: self.assertAllClose(expected, actual, atol=atol) + @test_util.run_v2_only + def test_state_saving_and_loading(self): + input_data = np.random.random((1, 2)) + rff_layer = kernel_layers.RandomFourierFeatures(output_dim=10, scale=3.0) + inputs = input_layer.Input((2,)) + outputs = rff_layer(inputs) + model = training.Model(inputs, outputs) + output_data = model.predict(input_data) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + saved_model_dir = os.path.join(temp_dir, 'rff_model') + model.save(saved_model_dir) + new_model = save.load_model(saved_model_dir) + new_output_data = new_model.predict(input_data) + self.assertAllClose(output_data, new_output_data, atol=1e-4) + def test_invalid_output_dim(self): with self.assertRaisesRegexp( ValueError, r'`output_dim` should be a positive integer. Given: -3.'): @@ -246,7 +267,7 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase): num_trainable_vars = 1 if trainable else 0 self.assertLen(rff_layer.trainable_variables, num_trainable_vars) if trainable: - self.assertEqual('random_fourier_features/random_features_scale:0', + self.assertEqual('random_fourier_features/kernel_scale:0', rff_layer.trainable_variables[0].name) self.assertLen(rff_layer.non_trainable_variables, 3 - num_trainable_vars)