diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index 438e2056c69..5166ba37a35 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -134,7 +134,7 @@ py_library(
 
 py_test(
     name = "integration_test",
-    size = "small",
+    size = "medium",
     srcs = ["python/keras/integration_test.py"],
     srcs_version = "PY2AND3",
     tags = ["notsan"],
diff --git a/tensorflow/contrib/keras/python/keras/initializers.py b/tensorflow/contrib/keras/python/keras/initializers.py
index f9cb35e171e..b0b71e7cb4b 100644
--- a/tensorflow/contrib/keras/python/keras/initializers.py
+++ b/tensorflow/contrib/keras/python/keras/initializers.py
@@ -18,247 +18,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 import six
 
-from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.framework import tensor_shape
-
-
-class Initializer(object):
-  """Initializer base class: all initializers inherit from this class.
-  """
-
-  def __call__(self, shape, dtype=None):
-    raise NotImplementedError
-
-  def get_config(self):
-    return {}
-
-  @classmethod
-  def from_config(cls, config):
-    return cls(**config)
-
-
-class Zeros(Initializer):
-  """Initializer that generates tensors initialized to 0.
-  """
-
-  def __call__(self, shape, dtype=None):
-    return K.constant(0, shape=shape, dtype=dtype)
-
-
-class Ones(Initializer):
-  """Initializer that generates tensors initialized to 1.
-  """
-
-  def __call__(self, shape, dtype=None):
-    return K.constant(1, shape=shape, dtype=dtype)
-
-
-class Constant(Initializer):
-  """Initializer that generates tensors initialized to a constant value.
-
-  Arguments:
-      value: float; the value of the generator tensors.
-  """
-
-  def __init__(self, value=0):
-    self.value = value
-
-  def __call__(self, shape, dtype=None):
-    return K.constant(self.value, shape=shape, dtype=dtype)
-
-  def get_config(self):
-    return {'value': self.value}
-
-
-class RandomNormal(Initializer):
-  """Initializer that generates tensors with a normal distribution.
-
-  Arguments:
-      mean: a python scalar or a scalar tensor. Mean of the random values
-        to generate.
-      stddev: a python scalar or a scalar tensor. Standard deviation of the
-        random values to generate.
-      seed: A Python integer. Used to seed the random generator.
-  """
-
-  def __init__(self, mean=0., stddev=0.05, seed=None):
-    self.mean = mean
-    self.stddev = stddev
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    return K.random_normal(
-        shape, self.mean, self.stddev, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {'mean': self.mean, 'stddev': self.stddev, 'seed': self.seed}
-
-
-class RandomUniform(Initializer):
-  """Initializer that generates tensors with a uniform distribution.
-
-  Arguments:
-      minval: A python scalar or a scalar tensor. Lower bound of the range
-        of random values to generate.
-      maxval: A python scalar or a scalar tensor. Upper bound of the range
-        of random values to generate.  Defaults to 1 for float types.
-      seed: A Python integer. Used to seed the random generator.
-  """
-
-  def __init__(self, minval=-0.05, maxval=0.05, seed=None):
-    self.minval = minval
-    self.maxval = maxval
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    return K.random_uniform(
-        shape, self.minval, self.maxval, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {
-        'minval': self.minval,
-        'maxval': self.maxval,
-        'seed': self.seed,
-    }
-
-
-class TruncatedNormal(Initializer):
-  """Initializer that generates a truncated normal distribution.
-
-  These values are similar to values from a `RandomNormal`
-  except that values more than two standard deviations from the mean
-  are discarded and re-drawn. This is the recommended initializer for
-  neural network weights and filters.
-
-  Arguments:
-      mean: a python scalar or a scalar tensor. Mean of the random values
-        to generate.
-      stddev: a python scalar or a scalar tensor. Standard deviation of the
-        random values to generate.
-      seed: A Python integer. Used to seed the random generator.
-  """
-
-  def __init__(self, mean=0., stddev=0.05, seed=None):
-    self.mean = mean
-    self.stddev = stddev
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    return K.truncated_normal(
-        shape, self.mean, self.stddev, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {'mean': self.mean, 'stddev': self.stddev, 'seed': self.seed}
-
-
-class VarianceScaling(Initializer):
-  """Initializer capable of adapting its scale to the shape of weights.
-
-  With `distribution="normal"`, samples are drawn from a truncated normal
-  distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
-
-      - number of input units in the weight tensor, if mode = "fan_in"
-      - number of output units, if mode = "fan_out"
-      - average of the numbers of input and output units, if mode = "fan_avg"
-
-  With `distribution="uniform"`,
-  samples are drawn from a uniform distribution
-  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
-
-  Arguments:
-      scale: Scaling factor (positive float).
-      mode: One of "fan_in", "fan_out", "fan_avg".
-      distribution: Random distribution to use. One of "normal", "uniform".
-      seed: A Python integer. Used to seed the random generator.
-
-  Raises:
-      ValueError: In case of an invalid value for the "scale", mode" or
-        "distribution" arguments.
-  """
-
-  def __init__(self, scale=1.0, mode='fan_in', distribution='normal',
-               seed=None):
-    if scale <= 0.:
-      raise ValueError('`scale` must be a positive float. Got:', scale)
-    mode = mode.lower()
-    if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
-      raise ValueError('Invalid `mode` argument: '
-                       'expected on of {"fan_in", "fan_out", "fan_avg"} '
-                       'but got', mode)
-    distribution = distribution.lower()
-    if distribution not in {'normal', 'uniform'}:
-      raise ValueError('Invalid `distribution` argument: '
-                       'expected one of {"normal", "uniform"} '
-                       'but got', distribution)
-    self.scale = scale
-    self.mode = mode
-    self.distribution = distribution
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    fan_in, fan_out = _compute_fans(shape)
-    scale = self.scale
-    if self.mode == 'fan_in':
-      scale /= max(1., fan_in)
-    elif self.mode == 'fan_out':
-      scale /= max(1., fan_out)
-    else:
-      scale /= max(1., float(fan_in + fan_out) / 2)
-    if self.distribution == 'normal':
-      stddev = math.sqrt(scale)
-      return K.truncated_normal(shape, 0., stddev, dtype=dtype, seed=self.seed)
-    else:
-      limit = math.sqrt(3. * scale)
-      return K.random_uniform(shape, -limit, limit, dtype=dtype, seed=self.seed)
-
-  def get_config(self):
-    return {
-        'scale': self.scale,
-        'mode': self.mode,
-        'distribution': self.distribution,
-        'seed': self.seed
-    }
-
-
-class Orthogonal(Initializer):
-  """Initializer that generates a random orthogonal matrix.
-
-  Arguments:
-      gain: Multiplicative factor to apply to the orthogonal matrix.
-      seed: A Python integer. Used to seed the random generator.
-
-  References:
-      Saxe et al., http://arxiv.org/abs/1312.6120
-  """
-
-  def __init__(self, gain=1., seed=None):
-    self.gain = gain
-    self.seed = seed
-
-  def __call__(self, shape, dtype=None):
-    num_rows = 1
-    for dim in shape[:-1]:
-      num_rows *= dim
-    num_cols = shape[-1]
-    flat_shape = (num_rows, num_cols)
-    if self.seed is not None:
-      np.random.seed(self.seed)
-    a = np.random.normal(0.0, 1.0, flat_shape)
-    u, _, v = np.linalg.svd(a, full_matrices=False)
-    # Pick the one with the correct shape.
-    q = u if u.shape == flat_shape else v
-    q = q.reshape(shape)
-    return self.gain * q[:shape[0], :shape[1]]
-
-  def get_config(self):
-    return {'gain': self.gain, 'seed': self.seed}
+from tensorflow.python.ops.init_ops import Constant
+from tensorflow.python.ops.init_ops import Initializer
+from tensorflow.python.ops.init_ops import Ones
+from tensorflow.python.ops.init_ops import Orthogonal
+from tensorflow.python.ops.init_ops import RandomNormal
+from tensorflow.python.ops.init_ops import RandomUniform
+from tensorflow.python.ops.init_ops import TruncatedNormal
+from tensorflow.python.ops.init_ops import VarianceScaling
+from tensorflow.python.ops.init_ops import Zeros
 
 
 class Identity(Initializer):
@@ -406,47 +179,6 @@ orthogonal = Orthogonal
 
 
 # Utility functions
 
 
-def _compute_fans(shape, data_format='channels_last'):
-  """Computes the number of input and output units for a weight shape.
-
-  Arguments:
-      shape: Integer shape tuple.
-      data_format: Image data format to use for convolution kernels.
-          Note that all kernels in Keras are standardized on the
-          `channels_last` ordering (even when inputs are set
-          to `channels_first`).
-
-  Returns:
-      A tuple of scalars, `(fan_in, fan_out)`.
-
-  Raises:
-      ValueError: in case of invalid `data_format` argument.
-  """
-  shape = tensor_shape.TensorShape(shape).as_list()
-  if len(shape) == 2:
-    fan_in = shape[0]
-    fan_out = shape[1]
-  elif len(shape) in {3, 4, 5}:
-    # Assuming convolution kernels (1D, 2D or 3D).
-    # TH kernel shape: (depth, input_depth, ...)
-    # TF kernel shape: (..., input_depth, depth)
-    if data_format == 'channels_first':
-      receptive_field_size = np.prod(shape[2:])
-      fan_in = shape[1] * receptive_field_size
-      fan_out = shape[0] * receptive_field_size
-    elif data_format == 'channels_last':
-      receptive_field_size = np.prod(shape[:2])
-      fan_in = shape[-2] * receptive_field_size
-      fan_out = shape[-1] * receptive_field_size
-    else:
-      raise ValueError('Invalid data_format: ' + data_format)
-  else:
-    # No specific assumptions.
-    fan_in = math.sqrt(np.prod(shape))
-    fan_out = math.sqrt(np.prod(shape))
-  return fan_in, fan_out
-
-
 def serialize(initializer):
   return serialize_keras_object(initializer)
diff --git a/tensorflow/contrib/keras/python/keras/initializers_test.py b/tensorflow/contrib/keras/python/keras/initializers_test.py
index 7436fbb3904..c9f50c28eae 100644
--- a/tensorflow/contrib/keras/python/keras/initializers_test.py
+++ b/tensorflow/contrib/keras/python/keras/initializers_test.py
@@ -21,121 +21,132 @@ from __future__ import print_function
 
 import numpy as np
 
 from tensorflow.contrib.keras.python import keras
+from tensorflow.python.ops import init_ops
 from tensorflow.python.platform import test
 
 
-def _runner(init, shape, target_mean=None, target_std=None,
-            target_max=None, target_min=None):
-  variable = keras.backend.variable(init(shape))
-  output = keras.backend.get_value(variable)
-  lim = 3e-2
-  if target_std is not None:
-    assert abs(output.std() - target_std) < lim, output.std()
-  if target_mean is not None:
-    assert abs(output.mean() - target_mean) < lim, output.mean()
-  if target_max is not None:
-    assert abs(output.max() - target_max) < lim, output.max()
-  if target_min is not None:
-    assert abs(output.min() - target_min) < lim, output.min()
-
-
 class KerasInitializersTest(test.TestCase):
 
+  def _runner(self, init, shape, target_mean=None, target_std=None,
+              target_max=None, target_min=None):
+    variable = keras.backend.variable(init(shape))
+    output = keras.backend.get_value(variable)
+    lim = 3e-2
+    if target_std is not None:
+      self.assertGreater(lim, abs(output.std() - target_std))
+    if target_mean is not None:
+      self.assertGreater(lim, abs(output.mean() - target_mean))
+    if target_max is not None:
+      self.assertGreater(lim, abs(output.max() - target_max))
+    if target_min is not None:
+      self.assertGreater(lim, abs(output.min() - target_min))
+
+    # Test serialization (assumes deterministic behavior).
+    config = init.get_config()
+    reconstructed_init = init.__class__.from_config(config)
+    variable = keras.backend.variable(reconstructed_init(shape))
+    output_2 = keras.backend.get_value(variable)
+    self.assertAllClose(output, output_2, atol=1e-4)
+
   def test_uniform(self):
     tensor_shape = (9, 6, 7)
     with self.test_session():
-      _runner(keras.initializers.RandomUniform(minval=-1, maxval=1, seed=124),
-              tensor_shape,
-              target_mean=0., target_max=1, target_min=-1)
+      self._runner(keras.initializers.RandomUniform(minval=-1,
+                                                    maxval=1,
+                                                    seed=124),
+                   tensor_shape,
+                   target_mean=0., target_max=1, target_min=-1)
 
   def test_normal(self):
     tensor_shape = (8, 12, 99)
     with self.test_session():
-      _runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
-              tensor_shape,
-              target_mean=0., target_std=1)
+      self._runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
+                   tensor_shape,
+                   target_mean=0., target_std=1)
 
   def test_truncated_normal(self):
     tensor_shape = (12, 99, 7)
     with self.test_session():
-      _runner(keras.initializers.TruncatedNormal(mean=0, stddev=1, seed=126),
-              tensor_shape,
-              target_mean=0., target_std=None, target_max=2)
+      self._runner(keras.initializers.TruncatedNormal(mean=0,
+                                                      stddev=1,
+                                                      seed=126),
+                   tensor_shape,
+                   target_mean=0., target_std=None, target_max=2)
 
   def test_constant(self):
     tensor_shape = (5, 6, 4)
     with self.test_session():
-      _runner(keras.initializers.Constant(2), tensor_shape,
-              target_mean=2, target_max=2, target_min=2)
+      self._runner(keras.initializers.Constant(2), tensor_shape,
+                   target_mean=2, target_max=2, target_min=2)
 
   def test_lecun_uniform(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+      fan_in, _ = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(3. / fan_in)
-      _runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
-              target_mean=0., target_max=scale, target_min=-scale)
+      self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
+                   target_mean=0., target_max=scale, target_min=-scale)
 
   def test_glorot_uniform(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, fan_out = keras.initializers._compute_fans(tensor_shape)
+      fan_in, fan_out = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(6. / (fan_in + fan_out))
-      _runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
-              target_mean=0., target_max=scale, target_min=-scale)
+      self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
+                   target_mean=0., target_max=scale, target_min=-scale)
 
   def test_he_uniform(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+      fan_in, _ = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(6. / fan_in)
-      _runner(keras.initializers.he_uniform(seed=123), tensor_shape,
-              target_mean=0., target_max=scale, target_min=-scale)
+      self._runner(keras.initializers.he_uniform(seed=123), tensor_shape,
+                   target_mean=0., target_max=scale, target_min=-scale)
 
   def test_glorot_normal(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, fan_out = keras.initializers._compute_fans(tensor_shape)
+      fan_in, fan_out = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(2. / (fan_in + fan_out))
-      _runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
-              target_mean=0., target_std=None, target_max=2 * scale)
+      self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
+                   target_mean=0., target_std=None, target_max=2 * scale)
 
   def test_he_normal(self):
     tensor_shape = (5, 6, 4, 2)
     with self.test_session():
-      fan_in, _ = keras.initializers._compute_fans(tensor_shape)
+      fan_in, _ = init_ops._compute_fans(tensor_shape)
       scale = np.sqrt(2. / fan_in)
-      _runner(keras.initializers.he_normal(seed=123), tensor_shape,
-              target_mean=0., target_std=None, target_max=2 * scale)
+      self._runner(keras.initializers.he_normal(seed=123), tensor_shape,
+                   target_mean=0., target_std=None, target_max=2 * scale)
 
   def test_orthogonal(self):
-    tensor_shape = (7, 8)
+    tensor_shape = (10, 10)
    with self.test_session():
-      _runner(keras.initializers.orthogonal(seed=123), tensor_shape,
-              target_mean=0.)
+      self._runner(keras.initializers.orthogonal(seed=123), tensor_shape,
+                   target_mean=0.)
 
   def test_identity(self):
     with self.test_session():
       tensor_shape = (3, 4, 5)
       with self.assertRaises(ValueError):
-        _runner(keras.initializers.identity(), tensor_shape,
-                target_mean=1. / tensor_shape[0], target_max=1.)
+        self._runner(keras.initializers.identity(), tensor_shape,
+                     target_mean=1. / tensor_shape[0], target_max=1.)
 
       tensor_shape = (3, 3)
-      _runner(keras.initializers.identity(), tensor_shape,
-              target_mean=1. / tensor_shape[0], target_max=1.)
+      self._runner(keras.initializers.identity(), tensor_shape,
+                   target_mean=1. / tensor_shape[0], target_max=1.)
 
   def test_zero(self):
     tensor_shape = (4, 5)
     with self.test_session():
-      _runner(keras.initializers.zeros(), tensor_shape,
-              target_mean=0., target_max=0.)
+      self._runner(keras.initializers.zeros(), tensor_shape,
+                   target_mean=0., target_max=0.)
 
   def test_one(self):
     tensor_shape = (4, 5)
     with self.test_session():
-      _runner(keras.initializers.ones(), tensor_shape,
-              target_mean=1., target_max=1.)
+      self._runner(keras.initializers.ones(), tensor_shape,
+                   target_mean=1., target_max=1.)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/contrib/keras/python/keras/integration_test.py b/tensorflow/contrib/keras/python/keras/integration_test.py
index 3a3d36ca1c3..16d0713b31f 100644
--- a/tensorflow/contrib/keras/python/keras/integration_test.py
+++ b/tensorflow/contrib/keras/python/keras/integration_test.py
@@ -33,13 +33,13 @@ class KerasIntegrationTest(test.TestCase):
     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
         train_samples=200,
         test_samples=100,
-        input_shape=(8,),
+        input_shape=(10,),
         num_classes=2)
     y_train = keras.utils.to_categorical(y_train)
     y_test = keras.utils.to_categorical(y_test)
 
     model = keras.models.Sequential([
-        keras.layers.Dense(8,
+        keras.layers.Dense(16,
                            activation='relu',
                            input_shape=x_train.shape[1:]),
         keras.layers.Dropout(0.1),
@@ -59,13 +59,13 @@ class KerasIntegrationTest(test.TestCase):
     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
         train_samples=200,
         test_samples=100,
-        input_shape=(8,),
+        input_shape=(10,),
         num_classes=2)
     y_train = keras.utils.to_categorical(y_train)
     y_test = keras.utils.to_categorical(y_test)
 
     inputs = keras.layers.Input(shape=x_train.shape[1:])
-    x = keras.layers.Dense(8, activation='relu')(inputs)
+    x = keras.layers.Dense(16, activation='relu')(inputs)
     x = keras.layers.Dropout(0.1)(x)
     outputs = keras.layers.Dense(y_train.shape[-1], activation='softmax')(x)
diff --git a/tensorflow/contrib/keras/python/keras/optimizers_test.py b/tensorflow/contrib/keras/python/keras/optimizers_test.py
index b3aaddb7c0c..af5e3c99b96 100644
--- a/tensorflow/contrib/keras/python/keras/optimizers_test.py
+++ b/tensorflow/contrib/keras/python/keras/optimizers_test.py
@@ -41,7 +41,7 @@ def _test_optimizer(optimizer, target=0.75):
       input_shape=(10,),
       num_classes=2)
   y_train = keras.utils.to_categorical(y_train)
-  model = _get_model(x_train.shape[1], 10, y_train.shape[1])
+  model = _get_model(x_train.shape[1], 20, y_train.shape[1])
   model.compile(loss='categorical_crossentropy',
                 optimizer=optimizer,
                 metrics=['accuracy'])
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 60175965da8..67fff9c803b 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -49,30 +49,65 @@ class Initializer(object):
 
   def __call__(self, shape, dtype=None, partition_info=None):
     raise NotImplementedError
 
+  def get_config(self):
+    """Returns the configuration of the initializer as a JSON-serializable dict.
+
+    Returns:
+      A JSON-serializable Python dict.
+    """
+    return {}
+
+  @classmethod
+  def from_config(cls, config):
+    """Instantiates an initializer from a configuration dictionary.
+
+    Example:
+
+    ```
+    initializer = RandomUniform(-1, 1)
+    config = initializer.get_config()
+    initializer = RandomUniform.from_config(config)
+    ```
+
+    Arguments:
+      config: A Python dictionary.
+        It will typically be the output of `get_config`.
+
+    Returns:
+      An Initializer instance.
+    """
+    return cls(**config)
+
 
 class Zeros(Initializer):
   """Initializer that generates tensors initialized to 0."""
 
   def __init__(self, dtype=dtypes.float32):
-    self.dtype = dtype
+    self.dtype = dtypes.as_dtype(dtype)
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
       dtype = self.dtype
     return array_ops.zeros(shape, dtype)
 
+  def get_config(self):
+    return {"dtype": self.dtype.name}
+
 
 class Ones(Initializer):
   """Initializer that generates tensors initialized to 1."""
 
   def __init__(self, dtype=dtypes.float32):
-    self.dtype = dtype
+    self.dtype = dtypes.as_dtype(dtype)
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
       dtype = self.dtype
     return array_ops.ones(shape, dtype)
 
+  def get_config(self):
+    return {"dtype": self.dtype.name}
+
 
 class Constant(Initializer):
   """Initializer that generates tensors with constant values.
@@ -151,14 +186,27 @@ def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
     self.value = value
-    self.dtype = dtype
-    self.verify_shape = verify_shape
+    self.dtype = dtypes.as_dtype(dtype)
+    self._verify_shape = verify_shape
 
-  def __call__(self, shape, dtype=None, partition_info=None):
+  def __call__(self, shape,
+               dtype=None,
+               partition_info=None,
+               verify_shape=None):
     if dtype is None:
       dtype = self.dtype
+    if verify_shape is None:
+      verify_shape = self._verify_shape
     return constant_op.constant(self.value, dtype=dtype, shape=shape,
-                                verify_shape=self.verify_shape)
+                                verify_shape=verify_shape)
+
+  def get_config(self):
+    # We don't include `verify_shape` for compatibility with Keras.
+    # `verify_shape` should be passed as an argument to `__call__` rather
+    # than as a constructor argument: conceptually it isn't a property
+    # of the initializer.
+    return {"value": self.value,
+            "dtype": self.dtype.name}
 
 
 class RandomUniform(Initializer):
   """Initializer that generates tensors with a uniform distribution.
@@ -179,7 +227,7 @@
     self.minval = minval
     self.maxval = maxval
     self.seed = seed
-    self.dtype = dtype
+    self.dtype = dtypes.as_dtype(dtype)
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -187,6 +235,12 @@
     return random_ops.random_uniform(shape, self.minval, self.maxval,
                                      dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"minval": self.minval,
+            "maxval": self.maxval,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class RandomNormal(Initializer):
   """Initializer that generates tensors with a normal distribution.
@@ -206,7 +260,7 @@
     self.mean = mean
     self.stddev = stddev
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -214,6 +268,12 @@
     return random_ops.random_normal(shape, self.mean, self.stddev,
                                     dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"mean": self.mean,
+            "stddev": self.stddev,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class TruncatedNormal(Initializer):
   """Initializer that generates a truncated normal distribution.
@@ -238,7 +298,7 @@
     self.mean = mean
     self.stddev = stddev
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -246,6 +306,12 @@
     return random_ops.truncated_normal(shape, self.mean, self.stddev,
                                        dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"mean": self.mean,
+            "stddev": self.stddev,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class UniformUnitScaling(Initializer):
   """Initializer that generates tensors without scaling variance.
@@ -277,7 +343,7 @@
   def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
     self.factor = factor
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -298,6 +364,11 @@
     return random_ops.random_uniform(shape, -max_val, max_val,
                                      dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"factor": self.factor,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class VarianceScaling(Initializer):
   """Initializer capable of adapting its scale to the shape of weights tensors.
@@ -342,7 +413,7 @@
     self.mode = mode
     self.distribution = distribution
     self.seed = seed
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
 
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
@@ -367,6 +438,13 @@
       return random_ops.random_uniform(shape, -limit, limit,
                                        dtype, seed=self.seed)
 
+  def get_config(self):
+    return {"scale": self.scale,
+            "mode": self.mode,
+            "distribution": self.distribution,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 class Orthogonal(Initializer):
   """Initializer that generates an orthogonal matrix.
@@ -388,9 +466,9 @@
       for behavior.
   """
 
-  def __init__(self, gain=1.0, dtype=dtypes.float32, seed=None):
+  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
     self.gain = gain
-    self.dtype = _assert_float_dtype(dtype)
+    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
     self.seed = seed
 
   def __call__(self, shape, dtype=None, partition_info=None):
@@ -421,6 +499,11 @@
     q = array_ops.transpose(v)
     return self.gain * array_ops.reshape(q, shape)
 
+  def get_config(self):
+    return {"gain": self.gain,
+            "seed": self.seed,
+            "dtype": self.dtype.name}
+
 
 # Aliases.
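Taken together, the `init_ops.py` changes above give every core initializer a `get_config()`/`from_config()` round-trip, with `dtype` normalized through `dtypes.as_dtype()` at construction time so it can be serialized by name. A minimal sketch of the round-trip this enables, assuming the TF 1.x graph-mode API at this commit (the variable and session boilerplate is illustrative, not part of the change):

```python
import tensorflow as tf

# tf.random_uniform_initializer is the public alias of init_ops.RandomUniform.
init = tf.random_uniform_initializer(minval=-1.0, maxval=1.0, seed=42)
config = init.get_config()
# -> {'minval': -1.0, 'maxval': 1.0, 'seed': 42, 'dtype': 'float32'}

# from_config() just calls cls(**config); the 'float32' string is accepted
# because __init__ now normalizes it with dtypes.as_dtype().
restored = tf.random_uniform_initializer.from_config(config)

v1 = tf.get_variable('v1', shape=(3, 3), initializer=init)
v2 = tf.get_variable('v2', shape=(3, 3), initializer=restored)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  a, b = sess.run([v1, v2])
  # Same parameters and op-level seed, so the two draws match -- the same
  # property the new serialization check in _runner() above relies on.
```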
diff --git a/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt
index d34bfe51479..00ec669b168 100644
--- a/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'False\'], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt
index d84ddc6eb00..210b56242b2 100644
--- a/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt
index c8e266e70cf..13ec7454f41 100644
--- a/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt
@@ -5,6 +5,14 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'gain\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \"<dtype: \'float32\'>\", \'None\'], "
+    argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt
index 70308bc6014..5993fdeb9c2 100644
--- a/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt
index 37bb1956e82..a434ed1599e 100644
--- a/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt
index 7c48f4af076..c1e1c230a9f 100644
--- a/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt
index 4558db619e8..e1b18dc92fb 100644
--- a/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'factor\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt
index 8313009a68c..e229b02ceec 100644
--- a/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt
@@ -7,4 +7,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\"], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
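On the Keras side, `keras.initializers` now simply re-exports the `init_ops` classes, so the Keras serialization helpers operate on the shared implementations. A minimal sketch, assuming the contrib module layout in this diff; `serialize` is shown in `initializers.py` above, and `deserialize` is the standard Keras helper defined alongside it:

```python
from tensorflow.contrib.keras.python.keras import initializers

# This RandomNormal IS init_ops.RandomNormal, re-exported.
init = initializers.RandomNormal(mean=0.0, stddev=0.05, seed=7)

blob = initializers.serialize(init)
# -> {'class_name': 'RandomNormal',
#     'config': {'mean': 0.0, 'stddev': 0.05, 'seed': 7, 'dtype': 'float32'}}

# Round-trips through the shared get_config()/from_config() added above;
# the 'float32' string is handled by dtypes.as_dtype() in __init__.
clone = initializers.deserialize(blob)
assert isinstance(clone, initializers.RandomNormal)
```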