Handle num_rows < num_cols case in orthogonal_initializer properly (as the original SVD-based implementation would) instead of just padding with zeros.

Run pyformat.

PiperOrigin-RevId: 163478869
Author: A. Unique TensorFlower, 2017-07-28 09:33:43 -07:00 (committed by TensorFlower Gardener)
parent 046f912bb5
commit 7635e9db10
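
A minimal NumPy sketch of the behavioral difference the message describes (illustrative only, not part of the commit). With num_rows < num_cols, the old zero-padding produced a weight matrix whose trailing columns were exactly zero, so those input coordinates were ignored; the new transpose-based path yields a dense matrix with orthonormal rows, as the SVD-based implementation did:

    import numpy as np

    rng = np.random.RandomState(0)
    num_rows, num_cols = 3, 5  # fewer rows than columns

    # Old behavior: QR of a (num_rows, num_cols) matrix gives a square
    # (num_rows, num_rows) Q, which was then zero-padded to the target shape.
    q_small, _ = np.linalg.qr(rng.randn(num_rows, num_cols))
    w_old = np.hstack([q_small, np.zeros((num_rows, num_cols - num_rows))])
    print(w_old[:, num_rows:])  # the trailing columns are identically zero

    # New behavior: QR of the transposed shape, then transpose back.
    q_tall, _ = np.linalg.qr(rng.randn(num_cols, num_rows))  # orthonormal columns
    w_new = q_tall.T  # (num_rows, num_cols) with orthonormal rows, no zero block
    print(np.allclose(w_new.dot(w_new.T), np.eye(num_rows)))  # True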


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Operations often used for initializing tensors.
 
 All variable initializers returned by functions in this file should have the
@@ -190,24 +189,20 @@ class Constant(Initializer):
     self.dtype = dtypes.as_dtype(dtype)
     self._verify_shape = verify_shape
 
-  def __call__(self, shape,
-               dtype=None,
-               partition_info=None,
-               verify_shape=None):
+  def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
     if dtype is None:
       dtype = self.dtype
     if verify_shape is None:
       verify_shape = self._verify_shape
-    return constant_op.constant(self.value, dtype=dtype, shape=shape,
-                                verify_shape=verify_shape)
+    return constant_op.constant(
+        self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
 
   def get_config(self):
     # We don't include `verify_shape` for compatibility with Keras.
     # `verify_shape` should be passed as an argument to `__call__` rather
     # than as a constructor argument: conceptually it isn't a property
     # of the initializer.
-    return {"value": self.value,
-            "dtype": self.dtype.name}
+    return {"value": self.value, "dtype": self.dtype.name}
 
 
 class RandomUniform(Initializer):
@@ -233,14 +228,16 @@ class RandomUniform(Initializer):
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
       dtype = self.dtype
-    return random_ops.random_uniform(shape, self.minval, self.maxval,
-                                     dtype, seed=self.seed)
+    return random_ops.random_uniform(
+        shape, self.minval, self.maxval, dtype, seed=self.seed)
 
   def get_config(self):
-    return {"minval": self.minval,
-            "maxval": self.maxval,
-            "seed": self.seed,
-            "dtype": self.dtype.name}
+    return {
+        "minval": self.minval,
+        "maxval": self.maxval,
+        "seed": self.seed,
+        "dtype": self.dtype.name
+    }
 
 
 class RandomNormal(Initializer):
@@ -266,14 +263,16 @@ class RandomNormal(Initializer):
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
       dtype = self.dtype
-    return random_ops.random_normal(shape, self.mean, self.stddev,
-                                    dtype, seed=self.seed)
+    return random_ops.random_normal(
+        shape, self.mean, self.stddev, dtype, seed=self.seed)
 
   def get_config(self):
-    return {"mean": self.mean,
-            "stddev": self.stddev,
-            "seed": self.seed,
-            "dtype": self.dtype.name}
+    return {
+        "mean": self.mean,
+        "stddev": self.stddev,
+        "seed": self.seed,
+        "dtype": self.dtype.name
+    }
 
 
 class TruncatedNormal(Initializer):
@@ -304,14 +303,16 @@ class TruncatedNormal(Initializer):
   def __call__(self, shape, dtype=None, partition_info=None):
     if dtype is None:
       dtype = self.dtype
-    return random_ops.truncated_normal(shape, self.mean, self.stddev,
-                                       dtype, seed=self.seed)
+    return random_ops.truncated_normal(
+        shape, self.mean, self.stddev, dtype, seed=self.seed)
 
   def get_config(self):
-    return {"mean": self.mean,
-            "stddev": self.stddev,
-            "seed": self.seed,
-            "dtype": self.dtype.name}
+    return {
+        "mean": self.mean,
+        "stddev": self.stddev,
+        "seed": self.seed,
+        "dtype": self.dtype.name
+    }
 
 
 class UniformUnitScaling(Initializer):
@@ -362,13 +363,11 @@ class UniformUnitScaling(Initializer):
     # Avoid errors when initializing zero-size tensors.
     input_size = max(input_size, 1.0)
     max_val = math.sqrt(3 / input_size) * self.factor
-    return random_ops.random_uniform(shape, -max_val, max_val,
-                                     dtype, seed=self.seed)
+    return random_ops.random_uniform(
+        shape, -max_val, max_val, dtype, seed=self.seed)
 
   def get_config(self):
-    return {"factor": self.factor,
-            "seed": self.seed,
-            "dtype": self.dtype.name}
+    return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}
 
 
 class VarianceScaling(Initializer):
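
A side note on the UniformUnitScaling hunk above: max_val = sqrt(3 / input_size) * factor is chosen because a uniform draw on [-a, a] has variance a**2 / 3, so each weight gets variance factor**2 / input_size and a dot product over input_size unit-variance inputs keeps variance near factor**2. A quick check (my own, not from the commit):

    import numpy as np

    input_size, factor = 1024, 1.0
    max_val = np.sqrt(3.0 / input_size) * factor

    rng = np.random.RandomState(0)
    w = rng.uniform(-max_val, max_val, size=input_size)
    x = rng.randn(100000, input_size)  # unit-variance inputs
    print(np.var(x.dot(w)))  # close to factor**2 == 1.0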
@@ -398,7 +397,8 @@ class VarianceScaling(Initializer):
     "distribution" arguments.
   """
 
-  def __init__(self, scale=1.0,
+  def __init__(self,
+               scale=1.0,
                mode="fan_in",
                distribution="normal",
                seed=None,
@@ -432,27 +432,31 @@ class VarianceScaling(Initializer):
       scale /= max(1., (fan_in + fan_out) / 2.)
     if self.distribution == "normal":
       stddev = math.sqrt(scale)
-      return random_ops.truncated_normal(shape, 0.0, stddev,
-                                         dtype, seed=self.seed)
+      return random_ops.truncated_normal(
+          shape, 0.0, stddev, dtype, seed=self.seed)
     else:
       limit = math.sqrt(3.0 * scale)
-      return random_ops.random_uniform(shape, -limit, limit,
-                                       dtype, seed=self.seed)
+      return random_ops.random_uniform(
+          shape, -limit, limit, dtype, seed=self.seed)
 
   def get_config(self):
-    return {"scale": self.scale,
-            "mode": self.mode,
-            "distribution": self.distribution,
-            "seed": self.seed,
-            "dtype": self.dtype.name}
+    return {
+        "scale": self.scale,
+        "mode": self.mode,
+        "distribution": self.distribution,
+        "seed": self.seed,
+        "dtype": self.dtype.name
+    }
 
 
 class Orthogonal(Initializer):
   """Initializer that generates an orthogonal matrix.
 
   If the shape of the tensor to initialize is two-dimensional, it is initialized
-  with an orthogonal matrix obtained from the singular value decomposition of a
-  matrix of uniform random numbers.
+  with an orthogonal matrix obtained from the QR decomposition of a matrix of
+  uniform random numbers. If the matrix has fewer rows than columns then the
+  output will have orthogonal rows. Otherwise, the output will have orthogonal
+  columns.
 
   If the shape of the tensor to initialize is more than two-dimensional,
   a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
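
The two distribution branches in the VarianceScaling hunk above are variance-matched: once scale has been divided by the chosen fan, the normal branch uses stddev = sqrt(scale) and the uniform branch uses limit = sqrt(3 * scale), and since Var(U(-a, a)) = a**2 / 3 both draw weights with variance scale (ignoring the variance lost to truncation at two standard deviations). A small sketch, with fan values chosen arbitrarily:

    import math

    fan_in, fan_out, scale = 256, 512, 1.0
    s = scale / max(1.0, fan_in)          # mode="fan_in"
    stddev = math.sqrt(s)                 # "normal" branch
    limit = math.sqrt(3.0 * s)            # "uniform" branch
    print(stddev ** 2, limit ** 2 / 3.0)  # both equal s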
@@ -485,27 +489,23 @@ class Orthogonal(Initializer):
     for dim in shape[:-1]:
       num_rows *= dim
     num_cols = shape[-1]
-    flat_shape = (num_rows, num_cols)
+    flat_shape = (num_cols, num_rows) if num_rows < num_cols else (num_rows,
+                                                                   num_cols)
 
     # Generate a random matrix
     a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
     # Compute the qr factorization
     q, r = linalg_ops.qr(a, full_matrices=False)
     # Make Q uniform
-    square_len = math_ops.minimum(num_rows, num_cols)
-    d = array_ops.diag_part(r[:square_len, :square_len])
+    d = array_ops.diag_part(r)
     ph = d / math_ops.abs(d)
     q *= ph
-    # Pad zeros to Q (if rows smaller than cols)
     if num_rows < num_cols:
-      padding = array_ops.zeros([num_rows, num_cols - num_rows], dtype=dtype)
-      q = array_ops.concat([q, padding], 1)
+      q = array_ops.matrix_transpose(q)
     return self.gain * array_ops.reshape(q, shape)
 
   def get_config(self):
-    return {"gain": self.gain,
-            "seed": self.seed,
-            "dtype": self.dtype.name}
+    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
 
 
 # Aliases.
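
A rough NumPy transcription of the new __call__ logic above (the helper name orthogonal_sample is mine; the TF version operates on tensors via linalg_ops and array_ops):

    import numpy as np

    def orthogonal_sample(shape, gain=1.0, rng=np.random):
      num_rows = int(np.prod(shape[:-1]))
      num_cols = shape[-1]
      # Always hand QR a "tall" matrix so Q has orthonormal columns.
      flat_shape = (num_cols, num_rows) if num_rows < num_cols else (num_rows,
                                                                     num_cols)
      q, r = np.linalg.qr(rng.standard_normal(flat_shape))
      # Multiplying each column of Q by the sign of the matching diagonal
      # entry of R makes the factorization unique, which is what makes the
      # resulting Q uniformly distributed over orthogonal matrices.
      d = np.diag(r)
      q *= d / np.abs(d)
      if num_rows < num_cols:
        q = q.T  # orthonormal columns become orthonormal rows
      return gain * q.reshape(shape)

    w = orthogonal_sample((3, 5))
    print(np.allclose(w.dot(w.T), np.eye(3)))  # True: no zero padding needed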
@@ -520,6 +520,7 @@ truncated_normal_initializer = TruncatedNormal
 uniform_unit_scaling_initializer = UniformUnitScaling
 variance_scaling_initializer = VarianceScaling
 orthogonal_initializer = Orthogonal
 # pylint: enable=invalid-name
@@ -542,11 +543,8 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
   Returns:
     An initializer.
   """
-  return variance_scaling_initializer(scale=1.0,
-                                      mode="fan_avg",
-                                      distribution="uniform",
-                                      seed=seed,
-                                      dtype=dtype)
+  return variance_scaling_initializer(
+      scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
 
 
 def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
@@ -568,11 +566,8 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
   Returns:
     An initializer.
   """
-  return variance_scaling_initializer(scale=1.0,
-                                      mode="fan_avg",
-                                      distribution="normal",
-                                      seed=seed,
-                                      dtype=dtype)
+  return variance_scaling_initializer(
+      scale=1.0, mode="fan_avg", distribution="normal", seed=seed, dtype=dtype)
 
 
 # Utility functions.
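
For the two Glorot wrappers above, mode="fan_avg" divides scale by (fan_in + fan_out) / 2, so the uniform variant reduces to the classic Glorot bound sqrt(6 / (fan_in + fan_out)) and the normal variant to stddev sqrt(2 / (fan_in + fan_out)). A one-line check (fan values arbitrary):

    import math

    fan_in, fan_out = 300, 100
    s = 1.0 / max(1.0, (fan_in + fan_out) / 2.0)
    print(math.sqrt(3.0 * s), math.sqrt(6.0 / (fan_in + fan_out)))  # equal
    print(math.sqrt(s), math.sqrt(2.0 / (fan_in + fan_out)))        # equal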