Refactor Keras initializers to rely on core TF initializers; add serialization methods to core TF initializers.
Change: 153403157
parent ae84106edc
commit fd561221d2
Changed paths:
  tensorflow/contrib/keras
  tensorflow/python/ops
  tensorflow/tools/api/golden
    tensorflow.constant_initializer.pbtxt
    tensorflow.ones_initializer.pbtxt
    tensorflow.orthogonal_initializer.pbtxt
    tensorflow.random_normal_initializer.pbtxt
    tensorflow.random_uniform_initializer.pbtxt
    tensorflow.truncated_normal_initializer.pbtxt
    tensorflow.uniform_unit_scaling_initializer.pbtxt
    tensorflow.zeros_initializer.pbtxt
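What the new serialization methods buy you, as a minimal sketch against the TF 1.x internals touched below (variable names are illustrative, not part of the commit):

```
from tensorflow.python.ops import init_ops

init = init_ops.RandomUniform(minval=-1, maxval=1, seed=124)
config = init.get_config()
# config is a JSON-serializable dict, e.g.
# {'minval': -1, 'maxval': 1, 'seed': 124, 'dtype': 'float32'}
rebuilt = init_ops.RandomUniform.from_config(config)  # equivalent initializer
```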
tensorflow/contrib/keras/BUILD
@@ -134,7 +134,7 @@ py_library(
 
 py_test(
     name = "integration_test",
-    size = "small",
+    size = "medium",
     srcs = ["python/keras/integration_test.py"],
     srcs_version = "PY2AND3",
     tags = ["notsan"],
tensorflow/contrib/keras/python/keras/initializers.py
@@ -18,247 +18,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 import six
 
 from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.framework import tensor_shape
-
-
-class Initializer(object):
-  """Initializer base class: all initializers inherit from this class.
-  """
-
-  def __call__(self, shape, dtype=None):
-    raise NotImplementedError
-
-  def get_config(self):
-    return {}
-
-  @classmethod
-  def from_config(cls, config):
-    return cls(**config)
-
-
-class Zeros(Initializer):
-  """Initializer that generates tensors initialized to 0.
-  """
-
-  def __call__(self, shape, dtype=None):
-    return K.constant(0, shape=shape, dtype=dtype)
-
-
-class Ones(Initializer):
-  """Initializer that generates tensors initialized to 1.
-  """
-
-  def __call__(self, shape, dtype=None):
-    return K.constant(1, shape=shape, dtype=dtype)
-
-
-class Constant(Initializer):
-  """Initializer that generates tensors initialized to a constant value.
-
-  Arguments:
-      value: float; the value of the generator tensors.
-  """
-
-  def __init__(self, value=0):
-    self.value = value
-
-  def __call__(self, shape, dtype=None):
-    return K.constant(self.value, shape=shape, dtype=dtype)
-
-  def get_config(self):
-    return {'value': self.value}
-
-
-class RandomNormal(Initializer):
-  """Initializer that generates tensors with a normal distribution.
-
-  Arguments:
-      mean: a python scalar or a scalar tensor. Mean of the random values
-          to generate.
-      stddev: a python scalar or a scalar tensor. Standard deviation of the
-          random values to generate.
-      seed: A Python integer. Used to seed the random generator.
-  """
-
-  def __init__(self, mean=0., stddev=0.05, seed=None):
-    self.mean = mean
-    self.stddev = stddev
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    return K.random_normal(
-        shape, self.mean, self.stddev, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {'mean': self.mean, 'stddev': self.stddev, 'seed': self.seed}
-
-
-class RandomUniform(Initializer):
-  """Initializer that generates tensors with a uniform distribution.
-
-  Arguments:
-      minval: A python scalar or a scalar tensor. Lower bound of the range
-          of random values to generate.
-      maxval: A python scalar or a scalar tensor. Upper bound of the range
-          of random values to generate. Defaults to 1 for float types.
-      seed: A Python integer. Used to seed the random generator.
-  """
-
-  def __init__(self, minval=-0.05, maxval=0.05, seed=None):
-    self.minval = minval
-    self.maxval = maxval
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    return K.random_uniform(
-        shape, self.minval, self.maxval, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {
-        'minval': self.minval,
-        'maxval': self.maxval,
-        'seed': self.seed,
-    }
-
-
-class TruncatedNormal(Initializer):
-  """Initializer that generates a truncated normal distribution.
-
-  These values are similar to values from a `RandomNormal`
-  except that values more than two standard deviations from the mean
-  are discarded and re-drawn. This is the recommended initializer for
-  neural network weights and filters.
-
-  Arguments:
-      mean: a python scalar or a scalar tensor. Mean of the random values
-          to generate.
-      stddev: a python scalar or a scalar tensor. Standard deviation of the
-          random values to generate.
-      seed: A Python integer. Used to seed the random generator.
-  """
-
-  def __init__(self, mean=0., stddev=0.05, seed=None):
-    self.mean = mean
-    self.stddev = stddev
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    return K.truncated_normal(
-        shape, self.mean, self.stddev, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {'mean': self.mean, 'stddev': self.stddev, 'seed': self.seed}
-
-
-class VarianceScaling(Initializer):
-  """Initializer capable of adapting its scale to the shape of weights.
-
-  With `distribution="normal"`, samples are drawn from a truncated normal
-  distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
-
-    - number of input units in the weight tensor, if mode = "fan_in"
-    - number of output units, if mode = "fan_out"
-    - average of the numbers of input and output units, if mode = "fan_avg"
-
-  With `distribution="uniform"`,
-  samples are drawn from a uniform distribution
-  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
-
-  Arguments:
-      scale: Scaling factor (positive float).
-      mode: One of "fan_in", "fan_out", "fan_avg".
-      distribution: Random distribution to use. One of "normal", "uniform".
-      seed: A Python integer. Used to seed the random generator.
-
-  Raises:
-      ValueError: In case of an invalid value for the "scale", mode" or
-        "distribution" arguments.
-  """
-
-  def __init__(self, scale=1.0, mode='fan_in', distribution='normal',
-               seed=None):
-    if scale <= 0.:
-      raise ValueError('`scale` must be a positive float. Got:', scale)
-    mode = mode.lower()
-    if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
-      raise ValueError('Invalid `mode` argument: '
-                       'expected on of {"fan_in", "fan_out", "fan_avg"} '
-                       'but got', mode)
-    distribution = distribution.lower()
-    if distribution not in {'normal', 'uniform'}:
-      raise ValueError('Invalid `distribution` argument: '
-                       'expected one of {"normal", "uniform"} '
-                       'but got', distribution)
-    self.scale = scale
-    self.mode = mode
-    self.distribution = distribution
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    fan_in, fan_out = _compute_fans(shape)
-    scale = self.scale
-    if self.mode == 'fan_in':
-      scale /= max(1., fan_in)
-    elif self.mode == 'fan_out':
-      scale /= max(1., fan_out)
-    else:
-      scale /= max(1., float(fan_in + fan_out) / 2)
-    if self.distribution == 'normal':
-      stddev = math.sqrt(scale)
-      return K.truncated_normal(shape, 0., stddev, dtype=dtype, seed=self.seed)
-    else:
-      limit = math.sqrt(3. * scale)
-      return K.random_uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {
-        'scale': self.scale,
-        'mode': self.mode,
-        'distribution': self.distribution,
-        'seed': self.seed
-    }
-
-
-class Orthogonal(Initializer):
-  """Initializer that generates a random orthogonal matrix.
-
-  Arguments:
-      gain: Multiplicative factor to apply to the orthogonal matrix.
-      seed: A Python integer. Used to seed the random generator.
-
-  References:
-      Saxe et al., http://arxiv.org/abs/1312.6120
-  """
-
-  def __init__(self, gain=1., seed=None):
-    self.gain = gain
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    num_rows = 1
-    for dim in shape[:-1]:
-      num_rows *= dim
-    num_cols = shape[-1]
-    flat_shape = (num_rows, num_cols)
-    if self.seed is not None:
-      np.random.seed(self.seed)
-    a = np.random.normal(0.0, 1.0, flat_shape)
-    u, _, v = np.linalg.svd(a, full_matrices=False)
-    # Pick the one with the correct shape.
-    q = u if u.shape == flat_shape else v
-    q = q.reshape(shape)
-    return self.gain * q[:shape[0], :shape[1]]
-
-  def get_config(self):
-    return {'gain': self.gain, 'seed': self.seed}
+from tensorflow.python.ops.init_ops import Constant
+from tensorflow.python.ops.init_ops import Initializer
+from tensorflow.python.ops.init_ops import Ones
+from tensorflow.python.ops.init_ops import Orthogonal
+from tensorflow.python.ops.init_ops import RandomNormal
+from tensorflow.python.ops.init_ops import RandomUniform
+from tensorflow.python.ops.init_ops import TruncatedNormal
+from tensorflow.python.ops.init_ops import VarianceScaling
+from tensorflow.python.ops.init_ops import Zeros
 
 
 class Identity(Initializer):
@@ -406,47 +179,6 @@ orthogonal = Orthogonal
 # Utility functions
 
 
-def _compute_fans(shape, data_format='channels_last'):
-  """Computes the number of input and output units for a weight shape.
-
-  Arguments:
-      shape: Integer shape tuple.
-      data_format: Image data format to use for convolution kernels.
-          Note that all kernels in Keras are standardized on the
-          `channels_last` ordering (even when inputs are set
-          to `channels_first`).
-
-  Returns:
-      A tuple of scalars, `(fan_in, fan_out)`.
-
-  Raises:
-      ValueError: in case of invalid `data_format` argument.
-  """
-  shape = tensor_shape.TensorShape(shape).as_list()
-  if len(shape) == 2:
-    fan_in = shape[0]
-    fan_out = shape[1]
-  elif len(shape) in {3, 4, 5}:
-    # Assuming convolution kernels (1D, 2D or 3D).
-    # TH kernel shape: (depth, input_depth, ...)
-    # TF kernel shape: (..., input_depth, depth)
-    if data_format == 'channels_first':
-      receptive_field_size = np.prod(shape[2:])
-      fan_in = shape[1] * receptive_field_size
-      fan_out = shape[0] * receptive_field_size
-    elif data_format == 'channels_last':
-      receptive_field_size = np.prod(shape[:2])
-      fan_in = shape[-2] * receptive_field_size
-      fan_out = shape[-1] * receptive_field_size
-    else:
-      raise ValueError('Invalid data_format: ' + data_format)
-  else:
-    # No specific assumptions.
-    fan_in = math.sqrt(np.prod(shape))
-    fan_out = math.sqrt(np.prod(shape))
-  return fan_in, fan_out
-
-
 def serialize(initializer):
   return serialize_keras_object(initializer)
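The net effect of the initializers.py rewrite above is that the Keras initializer names become re-exports of the core classes, so a config produced on either side round-trips on the other. A short sketch of the identity this establishes (the assert is illustrative, not from the diff):

```
from tensorflow.contrib.keras.python.keras import initializers
from tensorflow.python.ops import init_ops

# After this change, Keras and core TF share one class per initializer.
assert initializers.VarianceScaling is init_ops.VarianceScaling
```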
tensorflow/contrib/keras/python/keras/initializers_test.py
@@ -21,121 +21,132 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.contrib.keras.python import keras
+from tensorflow.python.ops import init_ops
 from tensorflow.python.platform import test
 
 
-def _runner(init, shape, target_mean=None, target_std=None,
-            target_max=None, target_min=None):
-  variable = keras.backend.variable(init(shape))
-  output = keras.backend.get_value(variable)
-  lim = 3e-2
-  if target_std is not None:
-    assert abs(output.std() - target_std) < lim, output.std()
-  if target_mean is not None:
-    assert abs(output.mean() - target_mean) < lim, output.mean()
-  if target_max is not None:
-    assert abs(output.max() - target_max) < lim, output.max()
-  if target_min is not None:
-    assert abs(output.min() - target_min) < lim, output.min()
-
-
 class KerasInitializersTest(test.TestCase):
 
+  def _runner(self, init, shape, target_mean=None, target_std=None,
+              target_max=None, target_min=None):
+    variable = keras.backend.variable(init(shape))
+    output = keras.backend.get_value(variable)
+    lim = 3e-2
+    if target_std is not None:
+      self.assertGreater(lim, abs(output.std() - target_std))
+    if target_mean is not None:
+      self.assertGreater(lim, abs(output.mean() - target_mean))
+    if target_max is not None:
+      self.assertGreater(lim, abs(output.max() - target_max))
+    if target_min is not None:
+      self.assertGreater(lim, abs(output.min() - target_min))
+
+    # Test serialization (assumes deterministic behavior).
+    config = init.get_config()
+    reconstructed_init = init.__class__.from_config(config)
+    variable = keras.backend.variable(reconstructed_init(shape))
+    output_2 = keras.backend.get_value(variable)
+    self.assertAllClose(output, output_2, atol=1e-4)
+
   def test_uniform(self):
     tensor_shape = (9, 6, 7)
     with self.test_session():
-      _runner(keras.initializers.RandomUniform(minval=-1, maxval=1, seed=124),
-              tensor_shape,
-              target_mean=0., target_max=1, target_min=-1)
+      self._runner(keras.initializers.RandomUniform(minval=-1,
+                                                    maxval=1,
+                                                    seed=124),
+                   tensor_shape,
+                   target_mean=0., target_max=1, target_min=-1)
 
   def test_normal(self):
     tensor_shape = (8, 12, 99)
     with self.test_session():
-      _runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
-              tensor_shape,
-              target_mean=0., target_std=1)
+      self._runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
+                   tensor_shape,
+                   target_mean=0., target_std=1)
 
   def test_truncated_normal(self):
     tensor_shape = (12, 99, 7)
     with self.test_session():
-      _runner(keras.initializers.TruncatedNormal(mean=0, stddev=1, seed=126),
-              tensor_shape,
-              target_mean=0., target_std=None, target_max=2)
+      self._runner(keras.initializers.TruncatedNormal(mean=0,
+                                                      stddev=1,
+                                                      seed=126),
+                   tensor_shape,
+                   target_mean=0., target_std=None, target_max=2)
 
   def test_constant(self):
     tensor_shape = (5, 6, 4)
     with self.test_session():
-      _runner(keras.initializers.Constant(2), tensor_shape,
-              target_mean=2, target_max=2, target_min=2)
+      self._runner(keras.initializers.Constant(2), tensor_shape,
+                   target_mean=2, target_max=2, target_min=2)
 
   def test_lecun_uniform(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+      fan_in, _ = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(3. / fan_in)
-      _runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
-              target_mean=0., target_max=scale, target_min=-scale)
+      self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
+                   target_mean=0., target_max=scale, target_min=-scale)
 
   def test_glorot_uniform(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, fan_out = keras.initializers._compute_fans(tensor_shape)
+      fan_in, fan_out = init_ops._compute_fans(tensor_shape)
      scale = np.sqrt(6. / (fan_in + fan_out))
-      _runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
-              target_mean=0., target_max=scale, target_min=-scale)
+      self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
+                   target_mean=0., target_max=scale, target_min=-scale)
 
   def test_he_uniform(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+      fan_in, _ = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(6. / fan_in)
-      _runner(keras.initializers.he_uniform(seed=123), tensor_shape,
-              target_mean=0., target_max=scale, target_min=-scale)
+      self._runner(keras.initializers.he_uniform(seed=123), tensor_shape,
+                   target_mean=0., target_max=scale, target_min=-scale)
 
   def test_glorot_normal(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, fan_out = keras.initializers._compute_fans(tensor_shape)
+      fan_in, fan_out = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(2. / (fan_in + fan_out))
-      _runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
-              target_mean=0., target_std=None, target_max=2 * scale)
+      self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
+                   target_mean=0., target_std=None, target_max=2 * scale)
 
   def test_he_normal(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+      fan_in, _ = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(2. / fan_in)
-      _runner(keras.initializers.he_normal(seed=123), tensor_shape,
-              target_mean=0., target_std=None, target_max=2 * scale)
+      self._runner(keras.initializers.he_normal(seed=123), tensor_shape,
+                   target_mean=0., target_std=None, target_max=2 * scale)
 
   def test_orthogonal(self):
-    tensor_shape = (7, 8)
+    tensor_shape = (10, 10)
     with self.test_session():
-      _runner(keras.initializers.orthogonal(seed=123), tensor_shape,
-              target_mean=0.)
+      self._runner(keras.initializers.orthogonal(seed=123), tensor_shape,
+                   target_mean=0.)
 
   def test_identity(self):
     with self.test_session():
       tensor_shape = (3, 4, 5)
       with self.assertRaises(ValueError):
-        _runner(keras.initializers.identity(), tensor_shape,
-                target_mean=1. / tensor_shape[0], target_max=1.)
+        self._runner(keras.initializers.identity(), tensor_shape,
+                     target_mean=1. / tensor_shape[0], target_max=1.)
 
       tensor_shape = (3, 3)
-      _runner(keras.initializers.identity(), tensor_shape,
-              target_mean=1. / tensor_shape[0], target_max=1.)
+      self._runner(keras.initializers.identity(), tensor_shape,
+                   target_mean=1. / tensor_shape[0], target_max=1.)
 
   def test_zero(self):
     tensor_shape = (4, 5)
     with self.test_session():
-      _runner(keras.initializers.zeros(), tensor_shape,
-              target_mean=0., target_max=0.)
+      self._runner(keras.initializers.zeros(), tensor_shape,
+                   target_mean=0., target_max=0.)
 
   def test_one(self):
     tensor_shape = (4, 5)
     with self.test_session():
-      _runner(keras.initializers.ones(), tensor_shape,
-              target_mean=1., target_max=1.)
+      self._runner(keras.initializers.ones(), tensor_shape,
+                   target_mean=1., target_max=1.)
 
 
 if __name__ == '__main__':
tensorflow/contrib/keras/python/keras/integration_test.py
@@ -33,13 +33,13 @@ class KerasIntegrationTest(test.TestCase):
     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
         train_samples=200,
         test_samples=100,
-        input_shape=(8,),
+        input_shape=(10,),
         num_classes=2)
     y_train = keras.utils.to_categorical(y_train)
     y_test = keras.utils.to_categorical(y_test)
 
     model = keras.models.Sequential([
-        keras.layers.Dense(8,
+        keras.layers.Dense(16,
                            activation='relu',
                            input_shape=x_train.shape[1:]),
         keras.layers.Dropout(0.1),
@@ -59,13 +59,13 @@ class KerasIntegrationTest(test.TestCase):
     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
         train_samples=200,
         test_samples=100,
-        input_shape=(8,),
+        input_shape=(10,),
         num_classes=2)
     y_train = keras.utils.to_categorical(y_train)
     y_test = keras.utils.to_categorical(y_test)
 
     inputs = keras.layers.Input(shape=x_train.shape[1:])
-    x = keras.layers.Dense(8, activation='relu')(inputs)
+    x = keras.layers.Dense(16, activation='relu')(inputs)
     x = keras.layers.Dropout(0.1)(x)
     outputs = keras.layers.Dense(y_train.shape[-1], activation='softmax')(x)
 
tensorflow/contrib/keras/python/keras/optimizers_test.py
@@ -41,7 +41,7 @@ def _test_optimizer(optimizer, target=0.75):
       input_shape=(10,),
       num_classes=2)
   y_train = keras.utils.to_categorical(y_train)
-  model = _get_model(x_train.shape[1], 10, y_train.shape[1])
+  model = _get_model(x_train.shape[1], 20, y_train.shape[1])
   model.compile(loss='categorical_crossentropy',
                 optimizer=optimizer,
                 metrics=['accuracy'])
tensorflow/python/ops/init_ops.py
@@ -49,30 +49,65 @@ class Initializer(object):
   def __call__(self, shape, dtype=None, partition_info=None):
     raise NotImplementedError
 
+  def get_config(self):
+    """Returns the configuration of the initializer as a JSON-serializable dict.
+
+    Returns:
+      A JSON-serializable Python dict.
+    """
+    return {}
+
+  @classmethod
+  def from_config(cls, config):
+    """Instantiates an initializer from a configuration dictionary.
+
+    Example:
+
+    ```
+    initializer = RandomUniform(-1, 1)
+    config = initializer.get_config()
+    initializer = RandomUniform.from_config(config)
+    ```
+
+    Arguments:
+      config: A Python dictionary.
+        It will typically be the output of `get_config`.
+
+    Returns:
+      An Initializer instance.
+    """
+    return cls(**config)
+
 
 class Zeros(Initializer):
   """Initializer that generates tensors initialized to 0."""
 
   def __init__(self, dtype=dtypes.float32):
-    self.dtype = dtype
+    self.dtype = dtypes.as_dtype(dtype)
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
       dtype = self.dtype
     return array_ops.zeros(shape, dtype)
 
+  def get_config(self):
+    return {"dtype": self.dtype.name}
+
 
 class Ones(Initializer):
   """Initializer that generates tensors initialized to 1."""
 
   def __init__(self, dtype=dtypes.float32):
-    self.dtype = dtype
+    self.dtype = dtypes.as_dtype(dtype)
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
       dtype = self.dtype
     return array_ops.ones(shape, dtype)
 
+  def get_config(self):
+    return {"dtype": self.dtype.name}
+
 
 class Constant(Initializer):
   """Initializer that generates tensors with constant values.
@@ -151,14 +186,27 @@ class Constant(Initializer):
 
   def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
     self.value = value
-    self.dtype = dtype
-    self.verify_shape = verify_shape
+    self.dtype = dtypes.as_dtype(dtype)
+    self._verify_shape = verify_shape
 
-  def __call__(self, shape, dtype=None, partition_info=None):
+  def __call__(self, shape,
+               dtype=None,
+               partition_info=None,
+               verify_shape=None):
     if dtype is None:
       dtype = self.dtype
+    if verify_shape is None:
+      verify_shape = self._verify_shape
     return constant_op.constant(self.value, dtype=dtype, shape=shape,
-                                verify_shape=self.verify_shape)
+                                verify_shape=verify_shape)
+
+  def get_config(self):
+    # We don't include `verify_shape` for compatibility with Keras.
+    # `verify_shape` should be passed as an argument to `__call__` rather
+    # than as a constructor argument: conceptually it isn't a property
+    # of the initializer.
+    return {"value": self.value,
+            "dtype": self.dtype.name}
 
 
 class RandomUniform(Initializer):
@@ -179,7 +227,7 @@ class RandomUniform(Initializer):
     self.minval = minval
     self.maxval = maxval
     self.seed = seed
-    self.dtype = dtype
+    self.dtype = dtypes.as_dtype(dtype)
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -187,6 +235,12 @@ class RandomUniform(Initializer):
     return random_ops.random_uniform(shape, self.minval, self.maxval,
                                      dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"minval": self.minval,
+            "maxval": self.maxval,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class RandomNormal(Initializer):
   """Initializer that generates tensors with a normal distribution.
@@ -206,7 +260,7 @@ class RandomNormal(Initializer):
     self.mean = mean
     self.stddev = stddev
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -214,6 +268,12 @@ class RandomNormal(Initializer):
     return random_ops.random_normal(shape, self.mean, self.stddev,
                                     dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"mean": self.mean,
+            "stddev": self.stddev,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class TruncatedNormal(Initializer):
   """Initializer that generates a truncated normal distribution.
@@ -238,7 +298,7 @@ class TruncatedNormal(Initializer):
     self.mean = mean
     self.stddev = stddev
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -246,6 +306,12 @@ class TruncatedNormal(Initializer):
     return random_ops.truncated_normal(shape, self.mean, self.stddev,
                                        dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"mean": self.mean,
+            "stddev": self.stddev,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class UniformUnitScaling(Initializer):
   """Initializer that generates tensors without scaling variance.
@@ -277,7 +343,7 @@ class UniformUnitScaling(Initializer):
   def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
     self.factor = factor
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -298,6 +364,11 @@ class UniformUnitScaling(Initializer):
     return random_ops.random_uniform(shape, -max_val, max_val,
                                      dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"factor": self.factor,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class VarianceScaling(Initializer):
   """Initializer capable of adapting its scale to the shape of weights tensors.
@@ -342,7 +413,7 @@ class VarianceScaling(Initializer):
     self.mode = mode
     self.distribution = distribution
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -367,6 +438,13 @@ class VarianceScaling(Initializer):
     return random_ops.random_uniform(shape, -limit, limit,
                                      dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"scale": self.scale,
+            "mode": self.mode,
+            "distribution": self.distribution,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class Orthogonal(Initializer):
   """Initializer that generates an orthogonal matrix.
@@ -388,9 +466,9 @@ class Orthogonal(Initializer):
       for behavior.
   """
 
-  def __init__(self, gain=1.0, dtype=dtypes.float32, seed=None):
+  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
     self.gain = gain
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
     self.seed = seed
 
   def __call__(self, shape, dtype=None, partition_info=None):
@@ -421,6 +499,11 @@ class Orthogonal(Initializer):
     q = array_ops.transpose(v)
     return self.gain * array_ops.reshape(q, shape)
 
+  def get_config(self):
+    return {"gain": self.gain,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 # Aliases.
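The golden API files below record the new public surface: every initializer gains `from_config` and `get_config`. As an illustration of the config dict these now return, drawn from the `get_config` bodies above (values follow the defaults visible in the argspecs; key order is unspecified):

```
from tensorflow.python.ops import init_ops

tn = init_ops.TruncatedNormal(mean=0.0, stddev=1.0, seed=42)
tn.get_config()
# -> {'mean': 0.0, 'stddev': 1.0, 'seed': 42, 'dtype': 'float32'}
```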
tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'False\'], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt
@@ -5,6 +5,14 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'gain\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
+    argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'factor\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
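These goldens track the public `tf.*` aliases (`tf.zeros_initializer` and friends), so the new methods are reachable from user code; a one-line sketch assuming a build that includes this change:

```
import tensorflow as tf

tf.zeros_initializer().get_config()  # {'dtype': 'float32'}
```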