Handle num_rows < num_cols case in orthogonal_initializer properly (as the original SVD-based implementation would) instead of just padding with zeros.
Run pyformat. PiperOrigin-RevId: 163478869
This commit is contained in:
parent
046f912bb5
commit
7635e9db10
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Operations often used for initializing tensors.
|
||||
|
||||
All variable initializers returned by functions in this file should have the
|
||||
@ -190,24 +189,20 @@ class Constant(Initializer):
|
||||
self.dtype = dtypes.as_dtype(dtype)
|
||||
self._verify_shape = verify_shape
|
||||
|
||||
def __call__(self, shape,
|
||||
dtype=None,
|
||||
partition_info=None,
|
||||
verify_shape=None):
|
||||
def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
|
||||
if dtype is None:
|
||||
dtype = self.dtype
|
||||
if verify_shape is None:
|
||||
verify_shape = self._verify_shape
|
||||
return constant_op.constant(self.value, dtype=dtype, shape=shape,
|
||||
verify_shape=verify_shape)
|
||||
return constant_op.constant(
|
||||
self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
|
||||
|
||||
def get_config(self):
|
||||
# We don't include `verify_shape` for compatibility with Keras.
|
||||
# `verify_shape` should be passed as an argument to `__call__` rather
|
||||
# than as a constructor argument: conceptually it isn't a property
|
||||
# of the initializer.
|
||||
return {"value": self.value,
|
||||
"dtype": self.dtype.name}
|
||||
return {"value": self.value, "dtype": self.dtype.name}
|
||||
|
||||
|
||||
class RandomUniform(Initializer):
|
||||
@ -233,14 +228,16 @@ class RandomUniform(Initializer):
|
||||
def __call__(self, shape, dtype=None, partition_info=None):
|
||||
if dtype is None:
|
||||
dtype = self.dtype
|
||||
return random_ops.random_uniform(shape, self.minval, self.maxval,
|
||||
dtype, seed=self.seed)
|
||||
return random_ops.random_uniform(
|
||||
shape, self.minval, self.maxval, dtype, seed=self.seed)
|
||||
|
||||
def get_config(self):
|
||||
return {"minval": self.minval,
|
||||
"maxval": self.maxval,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name}
|
||||
return {
|
||||
"minval": self.minval,
|
||||
"maxval": self.maxval,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name
|
||||
}
|
||||
|
||||
|
||||
class RandomNormal(Initializer):
|
||||
@ -266,14 +263,16 @@ class RandomNormal(Initializer):
|
||||
def __call__(self, shape, dtype=None, partition_info=None):
|
||||
if dtype is None:
|
||||
dtype = self.dtype
|
||||
return random_ops.random_normal(shape, self.mean, self.stddev,
|
||||
dtype, seed=self.seed)
|
||||
return random_ops.random_normal(
|
||||
shape, self.mean, self.stddev, dtype, seed=self.seed)
|
||||
|
||||
def get_config(self):
|
||||
return {"mean": self.mean,
|
||||
"stddev": self.stddev,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name}
|
||||
return {
|
||||
"mean": self.mean,
|
||||
"stddev": self.stddev,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name
|
||||
}
|
||||
|
||||
|
||||
class TruncatedNormal(Initializer):
|
||||
@ -304,14 +303,16 @@ class TruncatedNormal(Initializer):
|
||||
def __call__(self, shape, dtype=None, partition_info=None):
|
||||
if dtype is None:
|
||||
dtype = self.dtype
|
||||
return random_ops.truncated_normal(shape, self.mean, self.stddev,
|
||||
dtype, seed=self.seed)
|
||||
return random_ops.truncated_normal(
|
||||
shape, self.mean, self.stddev, dtype, seed=self.seed)
|
||||
|
||||
def get_config(self):
|
||||
return {"mean": self.mean,
|
||||
"stddev": self.stddev,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name}
|
||||
return {
|
||||
"mean": self.mean,
|
||||
"stddev": self.stddev,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name
|
||||
}
|
||||
|
||||
|
||||
class UniformUnitScaling(Initializer):
|
||||
@ -362,13 +363,11 @@ class UniformUnitScaling(Initializer):
|
||||
# Avoid errors when initializing zero-size tensors.
|
||||
input_size = max(input_size, 1.0)
|
||||
max_val = math.sqrt(3 / input_size) * self.factor
|
||||
return random_ops.random_uniform(shape, -max_val, max_val,
|
||||
dtype, seed=self.seed)
|
||||
return random_ops.random_uniform(
|
||||
shape, -max_val, max_val, dtype, seed=self.seed)
|
||||
|
||||
def get_config(self):
|
||||
return {"factor": self.factor,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name}
|
||||
return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}
|
||||
|
||||
|
||||
class VarianceScaling(Initializer):
|
||||
@ -398,7 +397,8 @@ class VarianceScaling(Initializer):
|
||||
"distribution" arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, scale=1.0,
|
||||
def __init__(self,
|
||||
scale=1.0,
|
||||
mode="fan_in",
|
||||
distribution="normal",
|
||||
seed=None,
|
||||
@ -432,27 +432,31 @@ class VarianceScaling(Initializer):
|
||||
scale /= max(1., (fan_in + fan_out) / 2.)
|
||||
if self.distribution == "normal":
|
||||
stddev = math.sqrt(scale)
|
||||
return random_ops.truncated_normal(shape, 0.0, stddev,
|
||||
dtype, seed=self.seed)
|
||||
return random_ops.truncated_normal(
|
||||
shape, 0.0, stddev, dtype, seed=self.seed)
|
||||
else:
|
||||
limit = math.sqrt(3.0 * scale)
|
||||
return random_ops.random_uniform(shape, -limit, limit,
|
||||
dtype, seed=self.seed)
|
||||
return random_ops.random_uniform(
|
||||
shape, -limit, limit, dtype, seed=self.seed)
|
||||
|
||||
def get_config(self):
|
||||
return {"scale": self.scale,
|
||||
"mode": self.mode,
|
||||
"distribution": self.distribution,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name}
|
||||
return {
|
||||
"scale": self.scale,
|
||||
"mode": self.mode,
|
||||
"distribution": self.distribution,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name
|
||||
}
|
||||
|
||||
|
||||
class Orthogonal(Initializer):
|
||||
"""Initializer that generates an orthogonal matrix.
|
||||
|
||||
If the shape of the tensor to initialize is two-dimensional, i is initialized
|
||||
with an orthogonal matrix obtained from the singular value decomposition of a
|
||||
matrix of uniform random numbers.
|
||||
with an orthogonal matrix obtained from the QR decomposition of a matrix of
|
||||
uniform random numbers. If the matrix has fewer rows than columns then the
|
||||
output will have orthogonal rows. Otherwise, the output will have orthogonal
|
||||
columns.
|
||||
|
||||
If the shape of the tensor to initialize is more than two-dimensional,
|
||||
a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
|
||||
@ -485,27 +489,23 @@ class Orthogonal(Initializer):
|
||||
for dim in shape[:-1]:
|
||||
num_rows *= dim
|
||||
num_cols = shape[-1]
|
||||
flat_shape = (num_rows, num_cols)
|
||||
flat_shape = (num_cols, num_rows) if num_rows < num_cols else (num_rows,
|
||||
num_cols)
|
||||
|
||||
# Generate a random matrix
|
||||
a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
|
||||
# Compute the qr factorization
|
||||
q, r = linalg_ops.qr(a, full_matrices=False)
|
||||
# Make Q uniform
|
||||
square_len = math_ops.minimum(num_rows, num_cols)
|
||||
d = array_ops.diag_part(r[:square_len, :square_len])
|
||||
d = array_ops.diag_part(r)
|
||||
ph = d / math_ops.abs(d)
|
||||
q *= ph
|
||||
# Pad zeros to Q (if rows smaller than cols)
|
||||
if num_rows < num_cols:
|
||||
padding = array_ops.zeros([num_rows, num_cols - num_rows], dtype=dtype)
|
||||
q = array_ops.concat([q, padding], 1)
|
||||
q = array_ops.matrix_transpose(q)
|
||||
return self.gain * array_ops.reshape(q, shape)
|
||||
|
||||
def get_config(self):
|
||||
return {"gain": self.gain,
|
||||
"seed": self.seed,
|
||||
"dtype": self.dtype.name}
|
||||
return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
|
||||
|
||||
|
||||
# Aliases.
|
||||
@ -520,6 +520,7 @@ truncated_normal_initializer = TruncatedNormal
|
||||
uniform_unit_scaling_initializer = UniformUnitScaling
|
||||
variance_scaling_initializer = VarianceScaling
|
||||
orthogonal_initializer = Orthogonal
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
|
||||
@ -542,11 +543,8 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
|
||||
Returns:
|
||||
An initializer.
|
||||
"""
|
||||
return variance_scaling_initializer(scale=1.0,
|
||||
mode="fan_avg",
|
||||
distribution="uniform",
|
||||
seed=seed,
|
||||
dtype=dtype)
|
||||
return variance_scaling_initializer(
|
||||
scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
|
||||
|
||||
|
||||
def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
|
||||
@ -568,11 +566,8 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
|
||||
Returns:
|
||||
An initializer.
|
||||
"""
|
||||
return variance_scaling_initializer(scale=1.0,
|
||||
mode="fan_avg",
|
||||
distribution="normal",
|
||||
seed=seed,
|
||||
dtype=dtype)
|
||||
return variance_scaling_initializer(
|
||||
scale=1.0, mode="fan_avg", distribution="normal", seed=seed, dtype=dtype)
|
||||
|
||||
|
||||
# Utility functions.
|
||||
|
Loading…
x
Reference in New Issue
Block a user