Include all parameters in LinearOperatorLowRankUpdate batch shape calculation.
PiperOrigin-RevId: 350756760 Change-Id: Ie8b278c4b7f44220d571f2239cc377c3535b25ff
This commit is contained in:
parent
97ba838a81
commit
59974d69d2
@ -287,24 +287,34 @@ class LinearOperatorLowRankUpdateBroadcastsShape(test.TestCase):
|
||||
@test_util.run_deprecated_v1
|
||||
def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
|
||||
num_rows_ph = array_ops.placeholder(dtypes.int32)
|
||||
|
||||
base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)
|
||||
|
||||
u_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||
u = array_ops.ones(shape=u_shape_ph)
|
||||
|
||||
operator = linalg.LinearOperatorLowRankUpdate(base_operator, u)
|
||||
v_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||
v = array_ops.ones(shape=v_shape_ph)
|
||||
|
||||
diag_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||
diag_update = array_ops.ones(shape=diag_shape_ph)
|
||||
|
||||
operator = linalg.LinearOperatorLowRankUpdate(base_operator,
|
||||
u=u,
|
||||
diag_update=diag_update,
|
||||
v=v)
|
||||
|
||||
feed_dict = {
|
||||
num_rows_ph: 3,
|
||||
u_shape_ph: [2, 3, 2], # batch_shape = [2]
|
||||
u_shape_ph: [1, 1, 2, 3, 2], # batch_shape = [1, 1, 2]
|
||||
v_shape_ph: [1, 2, 1, 3, 2], # batch_shape = [1, 2, 1]
|
||||
diag_shape_ph: [2, 1, 1, 2] # batch_shape = [2, 1, 1]
|
||||
}
|
||||
|
||||
with self.cached_session():
|
||||
shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
|
||||
self.assertAllEqual([2, 3, 3], shape_tensor)
|
||||
self.assertAllEqual([2, 2, 2, 3, 3], shape_tensor)
|
||||
dense = operator.to_dense().eval(feed_dict=feed_dict)
|
||||
self.assertAllEqual([2, 3, 3], dense.shape)
|
||||
self.assertAllEqual([2, 2, 2, 3, 3], dense.shape)
|
||||
|
||||
def test_u_and_v_incompatible_batch_shape_raises(self):
|
||||
base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
|
||||
|
@ -339,12 +339,21 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
def _shape(self):
|
||||
batch_shape = array_ops.broadcast_static_shape(
|
||||
self.base_operator.batch_shape,
|
||||
self.diag_operator.batch_shape)
|
||||
batch_shape = array_ops.broadcast_static_shape(
|
||||
batch_shape,
|
||||
self.u.shape[:-2])
|
||||
batch_shape = array_ops.broadcast_static_shape(
|
||||
batch_shape,
|
||||
self.v.shape[:-2])
|
||||
return batch_shape.concatenate(self.base_operator.shape[-2:])
|
||||
|
||||
def _shape_tensor(self):
|
||||
batch_shape = array_ops.broadcast_dynamic_shape(
|
||||
self.base_operator.batch_shape_tensor(),
|
||||
self.diag_operator.batch_shape_tensor())
|
||||
batch_shape = array_ops.broadcast_dynamic_shape(
|
||||
batch_shape,
|
||||
array_ops.shape(self.u)[:-2])
|
||||
batch_shape = array_ops.broadcast_dynamic_shape(
|
||||
batch_shape,
|
||||
|
Loading…
Reference in New Issue
Block a user