Make RandomFourierFeatures state saveable.

PiperOrigin-RevId: 313112328
Change-Id: I21c8881b84d8d40e90e3dc82bb38154bc928b5f4
This commit is contained in:
Francois Chollet 2020-05-25 17:04:14 -07:00 committed by TensorFlower Gardener
parent 5dbc34f565
commit 256332096c
2 changed files with 29 additions and 8 deletions

View File

@ -191,15 +191,15 @@ class RandomFourierFeatures(base_layer.Layer):
kernel_initializer = _get_random_features_initializer(
self.kernel_initializer, shape=(input_dim, self.output_dim))
unscaled_kernel = self.add_weight(
name='unscaled_random_features',
self.unscaled_kernel = self.add_weight(
name='unscaled_kernel',
shape=(input_dim, self.output_dim),
dtype=dtypes.float32,
initializer=kernel_initializer,
trainable=False)
self.bias = self.add_weight(
name='random_features_bias',
name='bias',
shape=(self.output_dim,),
dtype=dtypes.float32,
initializer=init_ops.random_uniform_initializer(
@ -208,20 +208,20 @@ class RandomFourierFeatures(base_layer.Layer):
if self.scale is None:
self.scale = _get_default_scale(self.kernel_initializer, input_dim)
scale = self.add_weight(
name='random_features_scale',
self.kernel_scale = self.add_weight(
name='kernel_scale',
shape=(1,),
dtype=dtypes.float32,
initializer=init_ops.constant_initializer(self.scale),
trainable=True,
constraint='NonNeg')
self.kernel = (1.0 / scale) * unscaled_kernel
super(RandomFourierFeatures, self).build(input_shape)
def call(self, inputs):
inputs = ops.convert_to_tensor_v2(inputs, dtype=self.dtype)
inputs = gen_math_ops.cast(inputs, dtypes.float32)
outputs = gen_math_ops.mat_mul(inputs, self.kernel)
kernel = (1.0 / self.kernel_scale) * self.unscaled_kernel
outputs = gen_math_ops.mat_mul(inputs, kernel)
outputs = nn.bias_add(outputs, self.bias)
return gen_math_ops.cos(outputs)

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import functools
import math
import os
import shutil
from absl.testing import parameterized
import numpy as np
@ -35,7 +37,10 @@ from tensorflow.python.keras import backend as keras_backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import kernelized as kernel_layers
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import kernelized_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@ -65,6 +70,22 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase):
else:
self.assertAllClose(expected, actual, atol=atol)
@test_util.run_v2_only
def test_state_saving_and_loading(self):
input_data = np.random.random((1, 2))
rff_layer = kernel_layers.RandomFourierFeatures(output_dim=10, scale=3.0)
inputs = input_layer.Input((2,))
outputs = rff_layer(inputs)
model = training.Model(inputs, outputs)
output_data = model.predict(input_data)
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
saved_model_dir = os.path.join(temp_dir, 'rff_model')
model.save(saved_model_dir)
new_model = save.load_model(saved_model_dir)
new_output_data = new_model.predict(input_data)
self.assertAllClose(output_data, new_output_data, atol=1e-4)
def test_invalid_output_dim(self):
with self.assertRaisesRegexp(
ValueError, r'`output_dim` should be a positive integer. Given: -3.'):
@ -246,7 +267,7 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase):
num_trainable_vars = 1 if trainable else 0
self.assertLen(rff_layer.trainable_variables, num_trainable_vars)
if trainable:
self.assertEqual('random_fourier_features/random_features_scale:0',
self.assertEqual('random_fourier_features/kernel_scale:0',
rff_layer.trainable_variables[0].name)
self.assertLen(rff_layer.non_trainable_variables, 3 - num_trainable_vars)