Name change in LinearOperator: batch_shape_dynamic --> batch_shape_tensor.

Similarly for other "dynamic" Ops.
Change: 144728885
Ian Langmore 2017-01-17 10:52:32 -08:00 committed by TensorFlower Gardener
parent 3c44578744
commit 64ea20632b
12 changed files with 74 additions and 72 deletions
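
For context, a minimal sketch of how the renamed methods read from user code. This is not taken from the diff; it assumes the tf.contrib.linalg package of this release and graph-mode TF 1.x:

import numpy as np
import tensorflow as tf

# A batch of two 3 x 4 matrices.
operator = tf.contrib.linalg.LinearOperatorMatrix(np.random.rand(2, 3, 4))

# Before this commit these were batch_shape_dynamic(), tensor_rank_dynamic(), etc.
batch_shape = operator.batch_shape_tensor()      # Tensor, evaluates to [2]
rank = operator.tensor_rank_tensor()             # evaluates to 3
domain_dim = operator.domain_dimension_tensor()  # evaluates to 4

with tf.Session() as sess:
  print(sess.run([batch_shape, rank, domain_dim]))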


@@ -1977,7 +1977,7 @@ class AffineLinearOperator(Bijector):
       if scale.tensor_rank is not None:
         batch_ndims = scale.tensor_rank - 2
       else:
-        batch_ndims = scale.tensor_rank_dynamic() - 2
+        batch_ndims = scale.tensor_rank_tensor() - 2
         graph_parents += [batch_ndims]
     else:
       batch_ndims = 0  # We won't need shape inference when scale is None.


@@ -200,16 +200,16 @@ class NonSquareLinearOperatorCompositionTest(
     operator = linalg.LinearOperatorComposition(operators)
     self.assertAllEqual((2, 3, 5), operator.shape)
 
-  def test_dynamic_shapes_when_statically_available(self):
+  def test_shape_tensors_when_statically_available(self):
     operators = [
         linalg.LinearOperatorMatrix(rng.rand(2, 3, 4)),
         linalg.LinearOperatorMatrix(rng.rand(2, 4, 5))
     ]
     operator = linalg.LinearOperatorComposition(operators)
     with self.test_session():
-      self.assertAllEqual((2, 3, 5), operator.shape_dynamic().eval())
+      self.assertAllEqual((2, 3, 5), operator.shape_tensor().eval())
 
-  def test_dynamic_shapes_when_only_dynamically_available(self):
+  def test_shape_tensors_when_only_dynamically_available(self):
     mat_1 = rng.rand(1, 2, 3, 4)
     mat_2 = rng.rand(1, 2, 4, 5)
     mat_ph_1 = array_ops.placeholder(dtypes.float64)
@@ -223,7 +223,7 @@ class NonSquareLinearOperatorCompositionTest(
     operator = linalg.LinearOperatorComposition(operators)
     with self.test_session():
       self.assertAllEqual(
-          (1, 2, 3, 5), operator.shape_dynamic().eval(feed_dict=feed_dict))
+          (1, 2, 3, 5), operator.shape_tensor().eval(feed_dict=feed_dict))
 
 if __name__ == "__main__":
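
The placeholder-based test above is the motivating case for these methods: when nothing is known statically, only the *_tensor variants can report the shape. A rough standalone version of the same idea (a hedged sketch assuming tf.contrib.linalg, not library or test code):

import numpy as np
import tensorflow as tf

mat_ph = tf.placeholder(tf.float64)  # Shape entirely unknown until fed.
operator = tf.contrib.linalg.LinearOperatorMatrix(mat_ph)

print(operator.shape)              # Static info only: unknown here.
shape_t = operator.shape_tensor()  # A Tensor that yields the shape at runtime.

with tf.Session() as sess:
  print(sess.run(shape_t, feed_dict={mat_ph: np.random.rand(1, 2, 3, 4)}))
  # -> [1 2 3 4]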


@@ -31,7 +31,7 @@ rng = np.random.RandomState(123)
 class LinearOperatorShape(linalg.LinearOperator):
-  """LinearOperator that implements the methods ._shape and _shape_dynamic."""
+  """LinearOperator that implements the methods ._shape and _shape_tensor."""
 
   def __init__(self,
               shape,
@@ -49,7 +49,7 @@ class LinearOperatorShape(linalg.LinearOperator):
   def _shape(self):
     return tensor_shape.TensorShape(self._stored_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return constant_op.constant(self._stored_shape, dtype=dtypes.int32)
@@ -71,7 +71,7 @@ class LinearOperatorApplyOnly(linalg.LinearOperator):
   def _shape(self):
     return self._matrix.get_shape()
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return array_ops.shape(self._matrix)
 
   def _apply(self, x, adjoint=False):
@@ -96,11 +96,11 @@ class LinearOperatorTest(test.TestCase):
       shape = (1, 2, 3, 4)
       operator = LinearOperatorShape(shape)
 
-      self.assertAllEqual(shape, operator.shape_dynamic().eval())
-      self.assertAllEqual(4, operator.tensor_rank_dynamic().eval())
-      self.assertAllEqual((1, 2), operator.batch_shape_dynamic().eval())
-      self.assertAllEqual(4, operator.domain_dimension_dynamic().eval())
-      self.assertAllEqual(3, operator.range_dimension_dynamic().eval())
+      self.assertAllEqual(shape, operator.shape_tensor().eval())
+      self.assertAllEqual(4, operator.tensor_rank_tensor().eval())
+      self.assertAllEqual((1, 2), operator.batch_shape_tensor().eval())
+      self.assertAllEqual(4, operator.domain_dimension_tensor().eval())
+      self.assertAllEqual(3, operator.range_dimension_tensor().eval())
 
   def test_is_x_properties(self):
     operator = LinearOperatorShape(
@@ -120,7 +120,7 @@ class LinearOperatorTest(test.TestCase):
       self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
       self.assertAllClose(matrix, operator_dense.eval())
 
-  def test_generic_to_dense_method_non_square_matrix_dynamic(self):
+  def test_generic_to_dense_method_non_square_matrix_tensor(self):
     matrix = rng.randn(2, 3, 4)
     matrix_ph = array_ops.placeholder(dtypes.float64)
     operator = LinearOperatorApplyOnly(matrix_ph)


@@ -96,7 +96,7 @@ class DomainDimensionStubOperator(object):
   def __init__(self, domain_dimension):
     self._domain_dimension = ops.convert_to_tensor(domain_dimension)
 
-  def domain_dimension_dynamic(self):
+  def domain_dimension_tensor(self):
     return self._domain_dimension


@@ -180,13 +180,15 @@ class LinearOperator(object):
     self._is_positive_definite = is_positive_definite
     self._name = name or type(self).__name__
 
-    # We will cache some values to avoid repeatedly adding shape
-    # manipulation ops to the graph. Cleaner.
-    self._cached_shape_dynamic = None
-    self._cached_batch_shape_dynamic = None
-    self._cached_domain_dimension_dynamic = None
-    self._cached_range_dimension_dynamic = None
-    self._cached_tensor_rank_dynamic = None
+    # We will cache some tensors to avoid repeatedly adding shape
+    # manipulation ops to the graph.
+    # Naming convention:
+    #   self._cached_X_tensor is the cached version of self._X_tensor.
+    self._cached_shape_tensor = None
+    self._cached_batch_shape_tensor = None
+    self._cached_domain_dimension_tensor = None
+    self._cached_range_dimension_tensor = None
+    self._cached_tensor_rank_tensor = None
 
   @contextlib.contextmanager
   def _name_scope(self, name=None, values=None):
@@ -240,10 +242,10 @@ class LinearOperator(object):
     """
     return self._shape()
 
-  def _shape_dynamic(self):
-    raise NotImplementedError("_shape_dynamic is not implemented.")
+  def _shape_tensor(self):
+    raise NotImplementedError("_shape_tensor is not implemented.")
 
-  def shape_dynamic(self, name="shape_dynamic"):
+  def shape_tensor(self, name="shape_tensor"):
     """Shape of this `LinearOperator`, determined at runtime.
 
     If this operator acts like the batch matrix `A` with
@@ -258,14 +260,14 @@ class LinearOperator(object):
     """
     with self._name_scope(name):
       # Be clean by avoiding adding shape Ops to the graph too many times.
-      if self._cached_shape_dynamic is None:
+      if self._cached_shape_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.shape.is_fully_defined():
-          self._cached_shape_dynamic = linear_operator_util.shape_tensor(
+          self._cached_shape_tensor = linear_operator_util.shape_tensor(
               self.shape.as_list())
         else:
-          self._cached_shape_dynamic = self._shape_dynamic()
-      return self._cached_shape_dynamic
+          self._cached_shape_tensor = self._shape_tensor()
+      return self._cached_shape_tensor
 
   @property
   def batch_shape(self):
@@ -281,7 +283,7 @@ class LinearOperator(object):
     # Derived classes get this "for free" once .shape is implemented.
     return self.shape[:-2]
 
-  def batch_shape_dynamic(self, name="batch_shape_dynamic"):
+  def batch_shape_tensor(self, name="batch_shape_tensor"):
     """Shape of batch dimensions of this operator, determined at runtime.
 
     If this operator acts like the batch matrix `A` with
@@ -296,14 +298,14 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_batch_shape_dynamic is None:
+      if self._cached_batch_shape_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.batch_shape.is_fully_defined():
-          self._cached_batch_shape_dynamic = linear_operator_util.shape_tensor(
+          self._cached_batch_shape_tensor = linear_operator_util.shape_tensor(
              self.batch_shape.as_list(), name="batch_shape")
         else:
-          self._cached_batch_shape_dynamic = self.shape_dynamic()[:-2]
-      return self._cached_batch_shape_dynamic
+          self._cached_batch_shape_tensor = self.shape_tensor()[:-2]
+      return self._cached_batch_shape_tensor
 
   @property
   def tensor_rank(self, name="tensor_rank"):
@@ -322,7 +324,7 @@ class LinearOperator(object):
     with self._name_scope(name):
       return self.shape.ndims
 
-  def tensor_rank_dynamic(self, name="tensor_rank_dynamic"):
+  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
     """Rank (in the sense of tensors) of matrix corresponding to this operator.
 
     If this operator acts like the batch matrix `A` with
@@ -336,15 +338,15 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_tensor_rank_dynamic is None:
+      if self._cached_tensor_rank_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.tensor_rank is not None:
-          self._cached_tensor_rank_dynamic = ops.convert_to_tensor(
+          self._cached_tensor_rank_tensor = ops.convert_to_tensor(
              self.tensor_rank)
         else:
-          self._cached_tensor_rank_dynamic = array_ops.size(
-              self.shape_dynamic())
-      return self._cached_tensor_rank_dynamic
+          self._cached_tensor_rank_tensor = array_ops.size(
+              self.shape_tensor())
+      return self._cached_tensor_rank_tensor
 
   @property
   def domain_dimension(self):
@@ -359,7 +361,7 @@ class LinearOperator(object):
     # Derived classes get this "for free" once .shape is implemented.
     return self.shape[-1]
 
-  def domain_dimension_dynamic(self, name="domain_dimension_dynamic"):
+  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
     """Dimension (in the sense of vector spaces) of the domain of this operator.
 
     Determined at runtime.
@@ -375,14 +377,14 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_domain_dimension_dynamic is None:
+      if self._cached_domain_dimension_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.domain_dimension.value is not None:
-          self._cached_domain_dimension_dynamic = ops.convert_to_tensor(
+          self._cached_domain_dimension_tensor = ops.convert_to_tensor(
              self.domain_dimension.value)
         else:
-          self._cached_domain_dimension_dynamic = self.shape_dynamic()[-1]
-      return self._cached_domain_dimension_dynamic
+          self._cached_domain_dimension_tensor = self.shape_tensor()[-1]
+      return self._cached_domain_dimension_tensor
 
   @property
   def range_dimension(self):
@@ -397,7 +399,7 @@ class LinearOperator(object):
     # Derived classes get this "for free" once .shape is implemented.
     return self.shape[-2]
 
-  def range_dimension_dynamic(self, name="range_dimension_dynamic"):
+  def range_dimension_tensor(self, name="range_dimension_tensor"):
     """Dimension (in the sense of vector spaces) of the range of this operator.
 
     Determined at runtime.
@@ -413,14 +415,14 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_range_dimension_dynamic is None:
+      if self._cached_range_dimension_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.range_dimension.value is not None:
-          self._cached_range_dimension_dynamic = ops.convert_to_tensor(
+          self._cached_range_dimension_tensor = ops.convert_to_tensor(
              self.range_dimension.value)
         else:
-          self._cached_range_dimension_dynamic = self.shape_dynamic()[-2]
-      return self._cached_range_dimension_dynamic
+          self._cached_range_dimension_tensor = self.shape_tensor()[-2]
+      return self._cached_range_dimension_tensor
 
   def _assert_non_singular(self):
     raise NotImplementedError("assert_non_singular is not implemented.")
@@ -574,12 +576,12 @@ class LinearOperator(object):
     if self.batch_shape.is_fully_defined():
       batch_shape = self.batch_shape
     else:
-      batch_shape = self.batch_shape_dynamic()
+      batch_shape = self.batch_shape_tensor()
 
     if self.domain_dimension.value is not None:
       n = self.domain_dimension.value
     else:
-      n = self.domain_dimension_dynamic()
+      n = self.domain_dimension_tensor()
 
     eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
     return self.apply(eye)


@@ -202,7 +202,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
     return batch_shape.concatenate(matrix_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     # Avoid messy broadcasting if possible.
     if self.shape.is_fully_defined():
       return ops.convert_to_tensor(
@@ -212,14 +212,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
     # the graph.  Things will fail at runtime naturally if shapes are
     # incompatible.
     matrix_shape = array_ops.stack([
-        self.operators[0].range_dimension_dynamic(),
-        self.operators[-1].domain_dimension_dynamic()
+        self.operators[0].range_dimension_tensor(),
+        self.operators[-1].domain_dimension_tensor()
     ])
 
     # Dummy Tensor of zeros.  Will never be materialized.
-    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_dynamic())
+    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
     for operator in self.operators[1:]:
-      zeros += array_ops.zeros(shape=operator.batch_shape_dynamic())
+      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
     batch_shape = array_ops.shape(zeros)
 
     return array_ops.concat((batch_shape, matrix_shape), 0)
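
The zeros trick in _shape_tensor above derives the broadcast batch shape of several operators by letting TensorFlow's normal broadcasting combine dummy zero tensors. A standalone sketch of the same idea (function name is illustrative, not part of the library):

import tensorflow as tf

def broadcast_batch_shape(batch_shape_tensors):
  # Summing zero tensors of each batch shape makes broadcasting produce the
  # combined batch shape; incompatible shapes fail naturally at runtime.
  zeros = tf.zeros(shape=batch_shape_tensors[0])
  for s in batch_shape_tensors[1:]:
    zeros += tf.zeros(shape=s)
  return tf.shape(zeros)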


@@ -166,7 +166,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
     d_shape = self._diag.get_shape()
     return d_shape.concatenate(d_shape[-1:])
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     d_shape = array_ops.shape(self._diag)
     k = d_shape[-1]
     return array_ops.concat((d_shape, [k]), 0)


@@ -261,7 +261,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
     batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
     return batch_shape.concatenate(matrix_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     matrix_shape = array_ops.stack(
         (self._num_rows, self._num_rows), axis=0)
     if self._batch_shape_arg is None:
@@ -307,7 +307,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
     # Dynamic broadcast:
     #   Always add to an array of zeros, rather than using a "cond", since a
     #   cond would require copying data from GPU --> CPU.
-    special_shape = array_ops.concat((self.batch_shape_dynamic(), [1, 1]), 0)
+    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
     zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
     return x + zeros
@@ -320,10 +320,10 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
       return self._possibly_broadcast_batch_shape(x)
 
   def _determinant(self):
-    return array_ops.ones(shape=self.batch_shape_dynamic(), dtype=self.dtype)
+    return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)
 
   def _log_abs_determinant(self):
-    return array_ops.zeros(shape=self.batch_shape_dynamic(), dtype=self.dtype)
+    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
 
   def _solve(self, rhs, adjoint=False):
     return self._apply(rhs)
@@ -566,7 +566,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
     batch_shape = self.multiplier.get_shape()
     return batch_shape.concatenate(matrix_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     matrix_shape = array_ops.stack(
         (self._num_rows, self._num_rows), axis=0)
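
The dynamic-broadcast branch above relies on the same add-zeros idea: a zero tensor shaped batch_shape + [1, 1] stretches x up to the operator's batch shape without a cond. A rough standalone sketch (illustrative name, not library code):

import tensorflow as tf

def broadcast_to_batch(x, batch_shape_tensor):
  # The trailing [1, 1] leaves the matrix dimensions of x alone while its
  # batch dimensions broadcast against batch_shape_tensor.
  special_shape = tf.concat((batch_shape_tensor, [1, 1]), 0)
  return x + tf.zeros(shape=special_shape, dtype=x.dtype)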


@@ -157,7 +157,7 @@ class LinearOperatorMatrix(linear_operator.LinearOperator):
   def _shape(self):
     return self._matrix.get_shape()
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return array_ops.shape(self._matrix)
 
   def _apply(self, x, adjoint=False):


@@ -262,8 +262,8 @@ class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
       n = operator.domain_dimension.value
       x_shape = batch_shape + [n, r]
     else:
-      batch_shape = operator.batch_shape_dynamic()
-      n = operator.domain_dimension_dynamic()
+      batch_shape = operator.batch_shape_tensor()
+      n = operator.domain_dimension_tensor()
       x_shape = array_ops.concat((batch_shape, [n, r]), 0)
 
     return random_normal(x_shape, dtype=operator.dtype)
@@ -316,11 +316,11 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
       n = operator.domain_dimension.value
       x_shape = batch_shape + [n, r]
     else:
-      batch_shape = operator.batch_shape_dynamic()
+      batch_shape = operator.batch_shape_tensor()
       if adjoint:
-        n = operator.range_dimension_dynamic()
+        n = operator.range_dimension_tensor()
       else:
-        n = operator.domain_dimension_dynamic()
+        n = operator.domain_dimension_tensor()
       x_shape = array_ops.concat((batch_shape, [n, r]), 0)
 
     return random_normal(x_shape, dtype=operator.dtype)


@@ -157,7 +157,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
   def _shape(self):
     return self._tril.get_shape()
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return array_ops.shape(self._tril)
 
   def _assert_non_singular(self):


@@ -83,10 +83,10 @@ def assert_compatible_matrix_dimensions(operator, x):
   Returns:
     `Assert` `Op`.
   """
-  # Static checks are done in the base class.  Only dynamic asserts here.
+  # Static checks are done in the base class.  Only tensor asserts here.
   assert_same_dd = check_ops.assert_equal(
       array_ops.shape(x)[-2],
-      operator.domain_dimension_dynamic(),
+      operator.domain_dimension_tensor(),
       message=(
           "Incompatible matrix dimensions.  "
           "shape[-2] of argument to be the same as this operator"))