Fork the implementation of tf.initializer to keras.
Since the implementation between tf.initializer and keras.initializer are just duplicates, copy the function to keras so that it is standalone. This allows us to freely delete the code in tf if preferred. Also update the build dependency to be more explicit. PiperOrigin-RevId: 349305978 Change-Id: Ic49a160037a5a0a77bc8826c597ffd7fbeaa5011
This commit is contained in:
parent
c06cd62954
commit
0817aac6be
@ -212,8 +212,18 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":backend",
|
||||
"//tensorflow/python:init_ops_v2",
|
||||
"//tensorflow/python/keras/utils:engine_utils",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:stateless_random_ops",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python/keras/utils:generic_utils",
|
||||
"//tensorflow/python/keras/utils:tf_inspect",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_linalg_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
@ -120,7 +119,7 @@ class Initializer(object):
|
||||
|
||||
|
||||
@keras_export('keras.initializers.Zeros', 'keras.initializers.zeros', v1=[])
|
||||
class Zeros(init_ops_v2.Zeros, Initializer):
|
||||
class Zeros(Initializer):
|
||||
"""Initializer that generates tensors initialized to 0.
|
||||
|
||||
Also available via the shortcut function `tf.keras.initializers.zeros`.
|
||||
@ -147,11 +146,17 @@ class Zeros(init_ops_v2.Zeros, Initializer):
|
||||
(via `tf.keras.backend.set_floatx(float_dtype)`).
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
|
||||
_validate_kwargs(self.__class__.__name__, kwargs)
|
||||
dtype = _get_dtype(dtype)
|
||||
if not dtype.is_numpy_compatible or dtype == dtypes.string:
|
||||
raise ValueError('Expected numeric or boolean dtype, got %s.' % dtype)
|
||||
if _PARTITION_SHAPE in kwargs:
|
||||
shape = kwargs[_PARTITION_SHAPE]
|
||||
return array_ops.zeros(shape, dtype)
|
||||
|
||||
|
||||
@keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[])
|
||||
class Ones(init_ops_v2.Ones, Initializer):
|
||||
class Ones(Initializer):
|
||||
"""Initializer that generates tensors initialized to 1.
|
||||
|
||||
Also available via the shortcut function `tf.keras.initializers.ones`.
|
||||
@ -178,7 +183,13 @@ class Ones(init_ops_v2.Ones, Initializer):
|
||||
(via `tf.keras.backend.set_floatx(float_dtype)`).
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
|
||||
_validate_kwargs(self.__class__.__name__, kwargs)
|
||||
dtype = _get_dtype(dtype)
|
||||
if not dtype.is_numpy_compatible or dtype == dtypes.string:
|
||||
raise ValueError('Expected numeric or boolean dtype, got %s.' % dtype)
|
||||
if _PARTITION_SHAPE in kwargs:
|
||||
shape = kwargs[_PARTITION_SHAPE]
|
||||
return array_ops.ones(shape, dtype)
|
||||
|
||||
|
||||
@keras_export('keras.initializers.Constant',
|
||||
@ -232,7 +243,7 @@ class Constant(Initializer):
|
||||
@keras_export('keras.initializers.RandomUniform',
|
||||
'keras.initializers.random_uniform',
|
||||
v1=[])
|
||||
class RandomUniform(init_ops_v2.RandomUniform, Initializer):
|
||||
class RandomUniform(Initializer):
|
||||
"""Initializer that generates tensors with a uniform distribution.
|
||||
|
||||
Also available via the shortcut function
|
||||
@ -257,6 +268,12 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
|
||||
always produce the same random tensor for a given shape and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, minval=-0.05, maxval=0.05, seed=None):
|
||||
self.minval = minval
|
||||
self.maxval = maxval
|
||||
self.seed = seed
|
||||
self._random_generator = _RandomGenerator(seed)
|
||||
|
||||
def __call__(self, shape, dtype=None, **kwargs):
|
||||
"""Returns a tensor object initialized as specified by the initializer.
|
||||
|
||||
@ -269,14 +286,27 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
|
||||
(via `tf.keras.backend.set_floatx(float_dtype)`).
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
return super(RandomUniform, self).__call__(
|
||||
shape, dtype=_get_dtype(dtype), **kwargs)
|
||||
_validate_kwargs(self.__class__.__name__, kwargs)
|
||||
dtype = _get_dtype(dtype)
|
||||
if not dtype.is_floating and not dtype.is_integer:
|
||||
raise ValueError('Expected float or integer dtype, got %s.' % dtype)
|
||||
if _PARTITION_SHAPE in kwargs:
|
||||
shape = kwargs[_PARTITION_SHAPE]
|
||||
return self._random_generator.random_uniform(shape, self.minval,
|
||||
self.maxval, dtype)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
'minval': self.minval,
|
||||
'maxval': self.maxval,
|
||||
'seed': self.seed
|
||||
}
|
||||
|
||||
|
||||
@keras_export('keras.initializers.RandomNormal',
|
||||
'keras.initializers.random_normal',
|
||||
v1=[])
|
||||
class RandomNormal(init_ops_v2.RandomNormal, Initializer):
|
||||
class RandomNormal(Initializer):
|
||||
"""Initializer that generates tensors with a normal distribution.
|
||||
|
||||
Also available via the shortcut function
|
||||
@ -301,6 +331,12 @@ class RandomNormal(init_ops_v2.RandomNormal, Initializer):
|
||||
always produce the same random tensor for a given shape and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, mean=0.0, stddev=0.05, seed=None):
|
||||
self.mean = mean
|
||||
self.stddev = stddev
|
||||
self.seed = seed
|
||||
self._random_generator = _RandomGenerator(seed)
|
||||
|
||||
def __call__(self, shape, dtype=None, **kwargs):
|
||||
"""Returns a tensor object initialized to random normal values.
|
||||
|
||||
@ -312,8 +348,19 @@ class RandomNormal(init_ops_v2.RandomNormal, Initializer):
|
||||
`tf.keras.backend.set_floatx(float_dtype)`)
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
return super(RandomNormal, self).__call__(
|
||||
shape, dtype=_get_dtype(dtype), **kwargs)
|
||||
_validate_kwargs(self.__class__.__name__, kwargs)
|
||||
dtype = _assert_float_dtype(_get_dtype(dtype))
|
||||
if _PARTITION_SHAPE in kwargs:
|
||||
shape = kwargs[_PARTITION_SHAPE]
|
||||
return self._random_generator.random_normal(shape, self.mean, self.stddev,
|
||||
dtype)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
'mean': self.mean,
|
||||
'stddev': self.stddev,
|
||||
'seed': self.seed
|
||||
}
|
||||
|
||||
|
||||
@keras_export('keras.initializers.TruncatedNormal',
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.Ones"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.RandomNormal"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.RandomUniform"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.Zeros"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.ones"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.random_normal"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.random_uniform"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.initializers.zeros"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.Ones"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.RandomNormal"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.RandomUniform"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.Zeros"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.ones"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Ones\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.random_normal"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomNormal\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.random_uniform"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.RandomUniform\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
@ -1,8 +1,6 @@
|
||||
path: "tensorflow.keras.initializers.zeros"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Zeros\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.initializers.initializers_v2.Initializer\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
|
Loading…
Reference in New Issue
Block a user