Make RandomFourierFeatures state saveable.
PiperOrigin-RevId: 313112328
Change-Id: I21c8881b84d8d40e90e3dc82bb38154bc928b5f4
parent 5dbc34f565
commit 256332096c
@@ -191,15 +191,15 @@ class RandomFourierFeatures(base_layer.Layer):
     kernel_initializer = _get_random_features_initializer(
         self.kernel_initializer, shape=(input_dim, self.output_dim))
 
-    unscaled_kernel = self.add_weight(
-        name='unscaled_random_features',
+    self.unscaled_kernel = self.add_weight(
+        name='unscaled_kernel',
         shape=(input_dim, self.output_dim),
         dtype=dtypes.float32,
         initializer=kernel_initializer,
         trainable=False)
 
     self.bias = self.add_weight(
-        name='random_features_bias',
+        name='bias',
         shape=(self.output_dim,),
         dtype=dtypes.float32,
         initializer=init_ops.random_uniform_initializer(
@@ -208,20 +208,20 @@ class RandomFourierFeatures(base_layer.Layer):
 
     if self.scale is None:
       self.scale = _get_default_scale(self.kernel_initializer, input_dim)
-    scale = self.add_weight(
-        name='random_features_scale',
+    self.kernel_scale = self.add_weight(
+        name='kernel_scale',
         shape=(1,),
         dtype=dtypes.float32,
         initializer=init_ops.constant_initializer(self.scale),
         trainable=True,
         constraint='NonNeg')
-    self.kernel = (1.0 / scale) * unscaled_kernel
     super(RandomFourierFeatures, self).build(input_shape)
 
   def call(self, inputs):
     inputs = ops.convert_to_tensor_v2(inputs, dtype=self.dtype)
     inputs = gen_math_ops.cast(inputs, dtypes.float32)
-    outputs = gen_math_ops.mat_mul(inputs, self.kernel)
+    kernel = (1.0 / self.kernel_scale) * self.unscaled_kernel
+    outputs = gen_math_ops.mat_mul(inputs, kernel)
     outputs = nn.bias_add(outputs, self.bias)
     return gen_math_ops.cos(outputs)
 
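The hunks above change the layer so that all of its state lives in tracked weights: the random projection is kept as self.unscaled_kernel and the scale as self.kernel_scale, and the scaled kernel is recomputed in call() instead of being cached as a derived self.kernel tensor at build time. Below is a minimal sketch of the same pattern written against the public tf.keras API; the ScaledRandomProjection layer is hypothetical and purely illustrative, not the RandomFourierFeatures implementation itself.

import tensorflow as tf

class ScaledRandomProjection(tf.keras.layers.Layer):
  """Hypothetical layer: state stored as weights, derived kernel in call()."""

  def __init__(self, output_dim, **kwargs):
    super(ScaledRandomProjection, self).__init__(**kwargs)
    self.output_dim = output_dim

  def build(self, input_shape):
    input_dim = int(input_shape[-1])
    # Raw state is registered via add_weight, so it is saved and restored
    # with the model.
    self.unscaled_kernel = self.add_weight(
        name='unscaled_kernel',
        shape=(input_dim, self.output_dim),
        initializer='random_normal',
        trainable=False)
    self.kernel_scale = self.add_weight(
        name='kernel_scale',
        shape=(1,),
        initializer=tf.keras.initializers.Constant(1.0),
        trainable=True)
    super(ScaledRandomProjection, self).build(input_shape)

  def call(self, inputs):
    # Derive the scaled kernel from the stored weights on every call,
    # rather than caching a plain tensor at build time.
    kernel = (1.0 / self.kernel_scale) * self.unscaled_kernel
    return tf.matmul(inputs, kernel)

  def get_config(self):
    config = super(ScaledRandomProjection, self).get_config()
    config.update({'output_dim': self.output_dim})
    return config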
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import functools
 import math
+import os
+import shutil
 
 from absl.testing import parameterized
 import numpy as np
@@ -35,7 +37,10 @@ from tensorflow.python.keras import backend as keras_backend
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.engine import input_layer
+from tensorflow.python.keras.engine import training
 from tensorflow.python.keras.layers import kernelized as kernel_layers
+from tensorflow.python.keras.saving import save
 from tensorflow.python.keras.utils import kernelized_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -65,6 +70,22 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase):
     else:
       self.assertAllClose(expected, actual, atol=atol)
 
+  @test_util.run_v2_only
+  def test_state_saving_and_loading(self):
+    input_data = np.random.random((1, 2))
+    rff_layer = kernel_layers.RandomFourierFeatures(output_dim=10, scale=3.0)
+    inputs = input_layer.Input((2,))
+    outputs = rff_layer(inputs)
+    model = training.Model(inputs, outputs)
+    output_data = model.predict(input_data)
+    temp_dir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, temp_dir)
+    saved_model_dir = os.path.join(temp_dir, 'rff_model')
+    model.save(saved_model_dir)
+    new_model = save.load_model(saved_model_dir)
+    new_output_data = new_model.predict(input_data)
+    self.assertAllClose(output_data, new_output_data, atol=1e-4)
+
   def test_invalid_output_dim(self):
     with self.assertRaisesRegexp(
         ValueError, r'`output_dim` should be a positive integer. Given: -3.'):
@@ -246,7 +267,7 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase):
     num_trainable_vars = 1 if trainable else 0
     self.assertLen(rff_layer.trainable_variables, num_trainable_vars)
     if trainable:
-      self.assertEqual('random_fourier_features/random_features_scale:0',
+      self.assertEqual('random_fourier_features/kernel_scale:0',
                        rff_layer.trainable_variables[0].name)
     self.assertLen(rff_layer.non_trainable_variables, 3 - num_trainable_vars)
 
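For reference, a rough user-facing equivalent of the new test_state_saving_and_loading above, written against the public API rather than the internal test utilities. It assumes the layer's public export is tf.keras.layers.experimental.RandomFourierFeatures and that the model is saved in the SavedModel format; '/tmp/rff_model' is just a placeholder path.

import numpy as np
import tensorflow as tf

# Build a small functional model around the layer.
input_data = np.random.random((1, 2)).astype(np.float32)
inputs = tf.keras.Input((2,))
outputs = tf.keras.layers.experimental.RandomFourierFeatures(
    output_dim=10, scale=3.0)(inputs)
model = tf.keras.Model(inputs, outputs)

# Round-trip through SavedModel and check that the outputs match.
before = model.predict(input_data)
model.save('/tmp/rff_model')
restored = tf.keras.models.load_model('/tmp/rff_model')
after = restored.predict(input_data)
np.testing.assert_allclose(before, after, atol=1e-4)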