Include all parameters in LinearOperatorLowRankUpdate batch shape calculation.

PiperOrigin-RevId: 350756760
Change-Id: Ie8b278c4b7f44220d571f2239cc377c3535b25ff
Dave Moore authored on 2021-01-08 07:07:53 -08:00; committed by TensorFlower Gardener
parent 97ba838a81
commit 59974d69d2
2 changed files with 24 additions and 5 deletions
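
With this change, the operator's batch shape is the broadcast of the batch shapes of all four inputs (base_operator, diag_update, u, and v), where previously only base_operator and u were consulted. A minimal sketch of the user-visible effect in eager mode, using the public tf.linalg aliases of the internal classes touched below, with shapes borrowed from the updated test:

    import tensorflow as tf

    base = tf.linalg.LinearOperatorIdentity(num_rows=3)  # batch_shape []
    u = tf.ones([1, 1, 2, 3, 2])                         # batch_shape [1, 1, 2]
    v = tf.ones([1, 2, 1, 3, 2])                         # batch_shape [1, 2, 1]
    diag_update = tf.ones([2, 1, 1, 2])                  # batch_shape [2, 1, 1]

    operator = tf.linalg.LinearOperatorLowRankUpdate(
        base, u=u, diag_update=diag_update, v=v)

    # All four batch shapes broadcast together to [2, 2, 2].
    print(operator.shape)             # (2, 2, 2, 3, 3)
    print(operator.to_dense().shape)  # (2, 2, 2, 3, 3)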


@@ -287,24 +287,34 @@ class LinearOperatorLowRankUpdateBroadcastsShape(test.TestCase):
   @test_util.run_deprecated_v1
   def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
     num_rows_ph = array_ops.placeholder(dtypes.int32)
     base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)
     u_shape_ph = array_ops.placeholder(dtypes.int32)
     u = array_ops.ones(shape=u_shape_ph)
-    operator = linalg.LinearOperatorLowRankUpdate(base_operator, u)
+    v_shape_ph = array_ops.placeholder(dtypes.int32)
+    v = array_ops.ones(shape=v_shape_ph)
+    diag_shape_ph = array_ops.placeholder(dtypes.int32)
+    diag_update = array_ops.ones(shape=diag_shape_ph)
+    operator = linalg.LinearOperatorLowRankUpdate(base_operator,
+                                                  u=u,
+                                                  diag_update=diag_update,
+                                                  v=v)
     feed_dict = {
         num_rows_ph: 3,
-        u_shape_ph: [2, 3, 2],  # batch_shape = [2]
+        u_shape_ph: [1, 1, 2, 3, 2],  # batch_shape = [1, 1, 2]
+        v_shape_ph: [1, 2, 1, 3, 2],  # batch_shape = [1, 2, 1]
+        diag_shape_ph: [2, 1, 1, 2]  # batch_shape = [2, 1, 1]
     }
     with self.cached_session():
       shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
-      self.assertAllEqual([2, 3, 3], shape_tensor)
+      self.assertAllEqual([2, 2, 2, 3, 3], shape_tensor)
       dense = operator.to_dense().eval(feed_dict=feed_dict)
-      self.assertAllEqual([2, 3, 3], dense.shape)
+      self.assertAllEqual([2, 2, 2, 3, 3], dense.shape)

   def test_u_and_v_incompatible_batch_shape_raises(self):
     base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
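
The expected shapes above follow NumPy-style broadcasting of the per-argument batch shapes; the fix in the next hunk realizes this by chaining pairwise broadcast_static_shape / broadcast_dynamic_shape calls. A small illustrative sketch of that fold (batch_broadcast is a hypothetical helper; np.broadcast_shapes requires NumPy >= 1.20):

    import numpy as np

    def batch_broadcast(*batch_shapes):
      # Fold the pairwise broadcast over all operands, mirroring the
      # chained broadcast_*_shape calls in _shape and _shape_tensor.
      out = ()
      for s in batch_shapes:
        out = np.broadcast_shapes(out, s)
      return out

    # Batch shapes from the test: the identity base contributes none.
    assert batch_broadcast((), (1, 1, 2), (1, 2, 1), (2, 1, 1)) == (2, 2, 2)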


@@ -339,12 +339,21 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
   def _shape(self):
     batch_shape = array_ops.broadcast_static_shape(
         self.base_operator.batch_shape,
+        self.diag_operator.batch_shape)
+    batch_shape = array_ops.broadcast_static_shape(
+        batch_shape,
         self.u.shape[:-2])
+    batch_shape = array_ops.broadcast_static_shape(
+        batch_shape,
+        self.v.shape[:-2])
     return batch_shape.concatenate(self.base_operator.shape[-2:])

   def _shape_tensor(self):
     batch_shape = array_ops.broadcast_dynamic_shape(
         self.base_operator.batch_shape_tensor(),
+        self.diag_operator.batch_shape_tensor())
+    batch_shape = array_ops.broadcast_dynamic_shape(
+        batch_shape,
         array_ops.shape(self.u)[:-2])
+    batch_shape = array_ops.broadcast_dynamic_shape(
+        batch_shape,