Name change in LinearOperator: batch_shape_dynamic --> batch_shape_tensor.
Similarly for other "dynamic" Ops. Change: 144728885
This commit is contained in:
parent
3c44578744
commit
64ea20632b
@ -1977,7 +1977,7 @@ class AffineLinearOperator(Bijector):
|
||||
if scale.tensor_rank is not None:
|
||||
batch_ndims = scale.tensor_rank - 2
|
||||
else:
|
||||
batch_ndims = scale.tensor_rank_dynamic() - 2
|
||||
batch_ndims = scale.tensor_rank_tensor() - 2
|
||||
graph_parents += [batch_ndims]
|
||||
else:
|
||||
batch_ndims = 0 # We won't need shape inference when scale is None.
|
||||
|
@ -200,16 +200,16 @@ class NonSquareLinearOperatorCompositionTest(
|
||||
operator = linalg.LinearOperatorComposition(operators)
|
||||
self.assertAllEqual((2, 3, 5), operator.shape)
|
||||
|
||||
def test_dynamic_shapes_when_statically_available(self):
|
||||
def test_shape_tensors_when_statically_available(self):
|
||||
operators = [
|
||||
linalg.LinearOperatorMatrix(rng.rand(2, 3, 4)),
|
||||
linalg.LinearOperatorMatrix(rng.rand(2, 4, 5))
|
||||
]
|
||||
operator = linalg.LinearOperatorComposition(operators)
|
||||
with self.test_session():
|
||||
self.assertAllEqual((2, 3, 5), operator.shape_dynamic().eval())
|
||||
self.assertAllEqual((2, 3, 5), operator.shape_tensor().eval())
|
||||
|
||||
def test_dynamic_shapes_when_only_dynamically_available(self):
|
||||
def test_shape_tensors_when_only_dynamically_available(self):
|
||||
mat_1 = rng.rand(1, 2, 3, 4)
|
||||
mat_2 = rng.rand(1, 2, 4, 5)
|
||||
mat_ph_1 = array_ops.placeholder(dtypes.float64)
|
||||
@ -223,7 +223,7 @@ class NonSquareLinearOperatorCompositionTest(
|
||||
operator = linalg.LinearOperatorComposition(operators)
|
||||
with self.test_session():
|
||||
self.assertAllEqual(
|
||||
(1, 2, 3, 5), operator.shape_dynamic().eval(feed_dict=feed_dict))
|
||||
(1, 2, 3, 5), operator.shape_tensor().eval(feed_dict=feed_dict))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -31,7 +31,7 @@ rng = np.random.RandomState(123)
|
||||
|
||||
|
||||
class LinearOperatorShape(linalg.LinearOperator):
|
||||
"""LinearOperator that implements the methods ._shape and _shape_dynamic."""
|
||||
"""LinearOperator that implements the methods ._shape and _shape_tensor."""
|
||||
|
||||
def __init__(self,
|
||||
shape,
|
||||
@ -49,7 +49,7 @@ class LinearOperatorShape(linalg.LinearOperator):
|
||||
def _shape(self):
|
||||
return tensor_shape.TensorShape(self._stored_shape)
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
return constant_op.constant(self._stored_shape, dtype=dtypes.int32)
|
||||
|
||||
|
||||
@ -71,7 +71,7 @@ class LinearOperatorApplyOnly(linalg.LinearOperator):
|
||||
def _shape(self):
|
||||
return self._matrix.get_shape()
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
return array_ops.shape(self._matrix)
|
||||
|
||||
def _apply(self, x, adjoint=False):
|
||||
@ -96,11 +96,11 @@ class LinearOperatorTest(test.TestCase):
|
||||
shape = (1, 2, 3, 4)
|
||||
operator = LinearOperatorShape(shape)
|
||||
|
||||
self.assertAllEqual(shape, operator.shape_dynamic().eval())
|
||||
self.assertAllEqual(4, operator.tensor_rank_dynamic().eval())
|
||||
self.assertAllEqual((1, 2), operator.batch_shape_dynamic().eval())
|
||||
self.assertAllEqual(4, operator.domain_dimension_dynamic().eval())
|
||||
self.assertAllEqual(3, operator.range_dimension_dynamic().eval())
|
||||
self.assertAllEqual(shape, operator.shape_tensor().eval())
|
||||
self.assertAllEqual(4, operator.tensor_rank_tensor().eval())
|
||||
self.assertAllEqual((1, 2), operator.batch_shape_tensor().eval())
|
||||
self.assertAllEqual(4, operator.domain_dimension_tensor().eval())
|
||||
self.assertAllEqual(3, operator.range_dimension_tensor().eval())
|
||||
|
||||
def test_is_x_properties(self):
|
||||
operator = LinearOperatorShape(
|
||||
@ -120,7 +120,7 @@ class LinearOperatorTest(test.TestCase):
|
||||
self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
|
||||
self.assertAllClose(matrix, operator_dense.eval())
|
||||
|
||||
def test_generic_to_dense_method_non_square_matrix_dynamic(self):
|
||||
def test_generic_to_dense_method_non_square_matrix_tensor(self):
|
||||
matrix = rng.randn(2, 3, 4)
|
||||
matrix_ph = array_ops.placeholder(dtypes.float64)
|
||||
operator = LinearOperatorApplyOnly(matrix_ph)
|
||||
|
@ -96,7 +96,7 @@ class DomainDimensionStubOperator(object):
|
||||
def __init__(self, domain_dimension):
|
||||
self._domain_dimension = ops.convert_to_tensor(domain_dimension)
|
||||
|
||||
def domain_dimension_dynamic(self):
|
||||
def domain_dimension_tensor(self):
|
||||
return self._domain_dimension
|
||||
|
||||
|
||||
|
@ -180,13 +180,15 @@ class LinearOperator(object):
|
||||
self._is_positive_definite = is_positive_definite
|
||||
self._name = name or type(self).__name__
|
||||
|
||||
# We will cache some values to avoid repeatedly adding shape
|
||||
# manipulation ops to the graph. Cleaner.
|
||||
self._cached_shape_dynamic = None
|
||||
self._cached_batch_shape_dynamic = None
|
||||
self._cached_domain_dimension_dynamic = None
|
||||
self._cached_range_dimension_dynamic = None
|
||||
self._cached_tensor_rank_dynamic = None
|
||||
# We will cache some tensors to avoid repeatedly adding shape
|
||||
# manipulation ops to the graph.
|
||||
# Naming convention:
|
||||
# self._cached_X_tensor is the cached version of self._X_tensor.
|
||||
self._cached_shape_tensor = None
|
||||
self._cached_batch_shape_tensor = None
|
||||
self._cached_domain_dimension_tensor = None
|
||||
self._cached_range_dimension_tensor = None
|
||||
self._cached_tensor_rank_tensor = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _name_scope(self, name=None, values=None):
|
||||
@ -240,10 +242,10 @@ class LinearOperator(object):
|
||||
"""
|
||||
return self._shape()
|
||||
|
||||
def _shape_dynamic(self):
|
||||
raise NotImplementedError("_shape_dynamic is not implemented.")
|
||||
def _shape_tensor(self):
|
||||
raise NotImplementedError("_shape_tensor is not implemented.")
|
||||
|
||||
def shape_dynamic(self, name="shape_dynamic"):
|
||||
def shape_tensor(self, name="shape_tensor"):
|
||||
"""Shape of this `LinearOperator`, determined at runtime.
|
||||
|
||||
If this operator acts like the batch matrix `A` with
|
||||
@ -258,14 +260,14 @@ class LinearOperator(object):
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
# Be clean by avoiding adding shape Ops to the graph too many times.
|
||||
if self._cached_shape_dynamic is None:
|
||||
if self._cached_shape_tensor is None:
|
||||
# Prefer to use statically defined shape if available.
|
||||
if self.shape.is_fully_defined():
|
||||
self._cached_shape_dynamic = linear_operator_util.shape_tensor(
|
||||
self._cached_shape_tensor = linear_operator_util.shape_tensor(
|
||||
self.shape.as_list())
|
||||
else:
|
||||
self._cached_shape_dynamic = self._shape_dynamic()
|
||||
return self._cached_shape_dynamic
|
||||
self._cached_shape_tensor = self._shape_tensor()
|
||||
return self._cached_shape_tensor
|
||||
|
||||
@property
|
||||
def batch_shape(self):
|
||||
@ -281,7 +283,7 @@ class LinearOperator(object):
|
||||
# Derived classes get this "for free" once .shape is implemented.
|
||||
return self.shape[:-2]
|
||||
|
||||
def batch_shape_dynamic(self, name="batch_shape_dynamic"):
|
||||
def batch_shape_tensor(self, name="batch_shape_tensor"):
|
||||
"""Shape of batch dimensions of this operator, determined at runtime.
|
||||
|
||||
If this operator acts like the batch matrix `A` with
|
||||
@ -296,14 +298,14 @@ class LinearOperator(object):
|
||||
"""
|
||||
# Derived classes get this "for free" once .shape() is implemented.
|
||||
with self._name_scope(name):
|
||||
if self._cached_batch_shape_dynamic is None:
|
||||
if self._cached_batch_shape_tensor is None:
|
||||
# Prefer to use statically defined shape if available.
|
||||
if self.batch_shape.is_fully_defined():
|
||||
self._cached_batch_shape_dynamic = linear_operator_util.shape_tensor(
|
||||
self._cached_batch_shape_tensor = linear_operator_util.shape_tensor(
|
||||
self.batch_shape.as_list(), name="batch_shape")
|
||||
else:
|
||||
self._cached_batch_shape_dynamic = self.shape_dynamic()[:-2]
|
||||
return self._cached_batch_shape_dynamic
|
||||
self._cached_batch_shape_tensor = self.shape_tensor()[:-2]
|
||||
return self._cached_batch_shape_tensor
|
||||
|
||||
@property
|
||||
def tensor_rank(self, name="tensor_rank"):
|
||||
@ -322,7 +324,7 @@ class LinearOperator(object):
|
||||
with self._name_scope(name):
|
||||
return self.shape.ndims
|
||||
|
||||
def tensor_rank_dynamic(self, name="tensor_rank_dynamic"):
|
||||
def tensor_rank_tensor(self, name="tensor_rank_tensor"):
|
||||
"""Rank (in the sense of tensors) of matrix corresponding to this operator.
|
||||
|
||||
If this operator acts like the batch matrix `A` with
|
||||
@ -336,15 +338,15 @@ class LinearOperator(object):
|
||||
"""
|
||||
# Derived classes get this "for free" once .shape() is implemented.
|
||||
with self._name_scope(name):
|
||||
if self._cached_tensor_rank_dynamic is None:
|
||||
if self._cached_tensor_rank_tensor is None:
|
||||
# Prefer to use statically defined shape if available.
|
||||
if self.tensor_rank is not None:
|
||||
self._cached_tensor_rank_dynamic = ops.convert_to_tensor(
|
||||
self._cached_tensor_rank_tensor = ops.convert_to_tensor(
|
||||
self.tensor_rank)
|
||||
else:
|
||||
self._cached_tensor_rank_dynamic = array_ops.size(
|
||||
self.shape_dynamic())
|
||||
return self._cached_tensor_rank_dynamic
|
||||
self._cached_tensor_rank_tensor = array_ops.size(
|
||||
self.shape_tensor())
|
||||
return self._cached_tensor_rank_tensor
|
||||
|
||||
@property
|
||||
def domain_dimension(self):
|
||||
@ -359,7 +361,7 @@ class LinearOperator(object):
|
||||
# Derived classes get this "for free" once .shape is implemented.
|
||||
return self.shape[-1]
|
||||
|
||||
def domain_dimension_dynamic(self, name="domain_dimension_dynamic"):
|
||||
def domain_dimension_tensor(self, name="domain_dimension_tensor"):
|
||||
"""Dimension (in the sense of vector spaces) of the domain of this operator.
|
||||
|
||||
Determined at runtime.
|
||||
@ -375,14 +377,14 @@ class LinearOperator(object):
|
||||
"""
|
||||
# Derived classes get this "for free" once .shape() is implemented.
|
||||
with self._name_scope(name):
|
||||
if self._cached_domain_dimension_dynamic is None:
|
||||
if self._cached_domain_dimension_tensor is None:
|
||||
# Prefer to use statically defined shape if available.
|
||||
if self.domain_dimension.value is not None:
|
||||
self._cached_domain_dimension_dynamic = ops.convert_to_tensor(
|
||||
self._cached_domain_dimension_tensor = ops.convert_to_tensor(
|
||||
self.domain_dimension.value)
|
||||
else:
|
||||
self._cached_domain_dimension_dynamic = self.shape_dynamic()[-1]
|
||||
return self._cached_domain_dimension_dynamic
|
||||
self._cached_domain_dimension_tensor = self.shape_tensor()[-1]
|
||||
return self._cached_domain_dimension_tensor
|
||||
|
||||
@property
|
||||
def range_dimension(self):
|
||||
@ -397,7 +399,7 @@ class LinearOperator(object):
|
||||
# Derived classes get this "for free" once .shape is implemented.
|
||||
return self.shape[-2]
|
||||
|
||||
def range_dimension_dynamic(self, name="range_dimension_dynamic"):
|
||||
def range_dimension_tensor(self, name="range_dimension_tensor"):
|
||||
"""Dimension (in the sense of vector spaces) of the range of this operator.
|
||||
|
||||
Determined at runtime.
|
||||
@ -413,14 +415,14 @@ class LinearOperator(object):
|
||||
"""
|
||||
# Derived classes get this "for free" once .shape() is implemented.
|
||||
with self._name_scope(name):
|
||||
if self._cached_range_dimension_dynamic is None:
|
||||
if self._cached_range_dimension_tensor is None:
|
||||
# Prefer to use statically defined shape if available.
|
||||
if self.range_dimension.value is not None:
|
||||
self._cached_range_dimension_dynamic = ops.convert_to_tensor(
|
||||
self._cached_range_dimension_tensor = ops.convert_to_tensor(
|
||||
self.range_dimension.value)
|
||||
else:
|
||||
self._cached_range_dimension_dynamic = self.shape_dynamic()[-2]
|
||||
return self._cached_range_dimension_dynamic
|
||||
self._cached_range_dimension_tensor = self.shape_tensor()[-2]
|
||||
return self._cached_range_dimension_tensor
|
||||
|
||||
def _assert_non_singular(self):
|
||||
raise NotImplementedError("assert_non_singular is not implemented.")
|
||||
@ -574,12 +576,12 @@ class LinearOperator(object):
|
||||
if self.batch_shape.is_fully_defined():
|
||||
batch_shape = self.batch_shape
|
||||
else:
|
||||
batch_shape = self.batch_shape_dynamic()
|
||||
batch_shape = self.batch_shape_tensor()
|
||||
|
||||
if self.domain_dimension.value is not None:
|
||||
n = self.domain_dimension.value
|
||||
else:
|
||||
n = self.domain_dimension_dynamic()
|
||||
n = self.domain_dimension_tensor()
|
||||
|
||||
eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
|
||||
return self.apply(eye)
|
||||
|
@ -202,7 +202,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
|
||||
return batch_shape.concatenate(matrix_shape)
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
# Avoid messy broadcasting if possible.
|
||||
if self.shape.is_fully_defined():
|
||||
return ops.convert_to_tensor(
|
||||
@ -212,14 +212,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
# the graph. Things will fail at runtime naturally if shapes are
|
||||
# incompatible.
|
||||
matrix_shape = array_ops.stack([
|
||||
self.operators[0].range_dimension_dynamic(),
|
||||
self.operators[-1].domain_dimension_dynamic()
|
||||
self.operators[0].range_dimension_tensor(),
|
||||
self.operators[-1].domain_dimension_tensor()
|
||||
])
|
||||
|
||||
# Dummy Tensor of zeros. Will never be materialized.
|
||||
zeros = array_ops.zeros(shape=self.operators[0].batch_shape_dynamic())
|
||||
zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
|
||||
for operator in self.operators[1:]:
|
||||
zeros += array_ops.zeros(shape=operator.batch_shape_dynamic())
|
||||
zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
|
||||
batch_shape = array_ops.shape(zeros)
|
||||
|
||||
return array_ops.concat((batch_shape, matrix_shape), 0)
|
||||
|
@ -166,7 +166,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
d_shape = self._diag.get_shape()
|
||||
return d_shape.concatenate(d_shape[-1:])
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
d_shape = array_ops.shape(self._diag)
|
||||
k = d_shape[-1]
|
||||
return array_ops.concat((d_shape, [k]), 0)
|
||||
|
@ -261,7 +261,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
|
||||
return batch_shape.concatenate(matrix_shape)
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
matrix_shape = array_ops.stack(
|
||||
(self._num_rows, self._num_rows), axis=0)
|
||||
if self._batch_shape_arg is None:
|
||||
@ -307,7 +307,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
# Dynamic broadcast:
|
||||
# Always add to an array of zeros, rather than using a "cond", since a
|
||||
# cond would require copying data from GPU --> CPU.
|
||||
special_shape = array_ops.concat((self.batch_shape_dynamic(), [1, 1]), 0)
|
||||
special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
|
||||
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
|
||||
return x + zeros
|
||||
|
||||
@ -320,10 +320,10 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
return self._possibly_broadcast_batch_shape(x)
|
||||
|
||||
def _determinant(self):
|
||||
return array_ops.ones(shape=self.batch_shape_dynamic(), dtype=self.dtype)
|
||||
return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)
|
||||
|
||||
def _log_abs_determinant(self):
|
||||
return array_ops.zeros(shape=self.batch_shape_dynamic(), dtype=self.dtype)
|
||||
return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
|
||||
|
||||
def _solve(self, rhs, adjoint=False):
|
||||
return self._apply(rhs)
|
||||
@ -566,7 +566,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
||||
batch_shape = self.multiplier.get_shape()
|
||||
return batch_shape.concatenate(matrix_shape)
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
matrix_shape = array_ops.stack(
|
||||
(self._num_rows, self._num_rows), axis=0)
|
||||
|
||||
|
@ -157,7 +157,7 @@ class LinearOperatorMatrix(linear_operator.LinearOperator):
|
||||
def _shape(self):
|
||||
return self._matrix.get_shape()
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
return array_ops.shape(self._matrix)
|
||||
|
||||
def _apply(self, x, adjoint=False):
|
||||
|
@ -262,8 +262,8 @@ class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
|
||||
n = operator.domain_dimension.value
|
||||
x_shape = batch_shape + [n, r]
|
||||
else:
|
||||
batch_shape = operator.batch_shape_dynamic()
|
||||
n = operator.domain_dimension_dynamic()
|
||||
batch_shape = operator.batch_shape_tensor()
|
||||
n = operator.domain_dimension_tensor()
|
||||
x_shape = array_ops.concat((batch_shape, [n, r]), 0)
|
||||
|
||||
return random_normal(x_shape, dtype=operator.dtype)
|
||||
@ -316,11 +316,11 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
|
||||
n = operator.domain_dimension.value
|
||||
x_shape = batch_shape + [n, r]
|
||||
else:
|
||||
batch_shape = operator.batch_shape_dynamic()
|
||||
batch_shape = operator.batch_shape_tensor()
|
||||
if adjoint:
|
||||
n = operator.range_dimension_dynamic()
|
||||
n = operator.range_dimension_tensor()
|
||||
else:
|
||||
n = operator.domain_dimension_dynamic()
|
||||
n = operator.domain_dimension_tensor()
|
||||
x_shape = array_ops.concat((batch_shape, [n, r]), 0)
|
||||
|
||||
return random_normal(x_shape, dtype=operator.dtype)
|
||||
|
@ -157,7 +157,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
|
||||
def _shape(self):
|
||||
return self._tril.get_shape()
|
||||
|
||||
def _shape_dynamic(self):
|
||||
def _shape_tensor(self):
|
||||
return array_ops.shape(self._tril)
|
||||
|
||||
def _assert_non_singular(self):
|
||||
|
@ -83,10 +83,10 @@ def assert_compatible_matrix_dimensions(operator, x):
|
||||
Returns:
|
||||
`Assert` `Op`.
|
||||
"""
|
||||
# Static checks are done in the base class. Only dynamic asserts here.
|
||||
# Static checks are done in the base class. Only tensor asserts here.
|
||||
assert_same_dd = check_ops.assert_equal(
|
||||
array_ops.shape(x)[-2],
|
||||
operator.domain_dimension_dynamic(),
|
||||
operator.domain_dimension_tensor(),
|
||||
message=(
|
||||
"Incompatible matrix dimensions. "
|
||||
"shape[-2] of argument to be the same as this operator"))
|
||||
|
Loading…
Reference in New Issue
Block a user