Include all parameters in LinearOperatorLowRankUpdate batch shape calculation.
PiperOrigin-RevId: 350756760 Change-Id: Ie8b278c4b7f44220d571f2239cc377c3535b25ff
This commit is contained in:
parent
97ba838a81
commit
59974d69d2
@ -287,24 +287,34 @@ class LinearOperatorLowRankUpdateBroadcastsShape(test.TestCase):
|
|||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
|
def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
|
||||||
num_rows_ph = array_ops.placeholder(dtypes.int32)
|
num_rows_ph = array_ops.placeholder(dtypes.int32)
|
||||||
|
|
||||||
base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)
|
base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)
|
||||||
|
|
||||||
u_shape_ph = array_ops.placeholder(dtypes.int32)
|
u_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||||
u = array_ops.ones(shape=u_shape_ph)
|
u = array_ops.ones(shape=u_shape_ph)
|
||||||
|
|
||||||
operator = linalg.LinearOperatorLowRankUpdate(base_operator, u)
|
v_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||||
|
v = array_ops.ones(shape=v_shape_ph)
|
||||||
|
|
||||||
|
diag_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||||
|
diag_update = array_ops.ones(shape=diag_shape_ph)
|
||||||
|
|
||||||
|
operator = linalg.LinearOperatorLowRankUpdate(base_operator,
|
||||||
|
u=u,
|
||||||
|
diag_update=diag_update,
|
||||||
|
v=v)
|
||||||
|
|
||||||
feed_dict = {
|
feed_dict = {
|
||||||
num_rows_ph: 3,
|
num_rows_ph: 3,
|
||||||
u_shape_ph: [2, 3, 2], # batch_shape = [2]
|
u_shape_ph: [1, 1, 2, 3, 2], # batch_shape = [1, 1, 2]
|
||||||
|
v_shape_ph: [1, 2, 1, 3, 2], # batch_shape = [1, 2, 1]
|
||||||
|
diag_shape_ph: [2, 1, 1, 2] # batch_shape = [2, 1, 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
|
shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
|
||||||
self.assertAllEqual([2, 3, 3], shape_tensor)
|
self.assertAllEqual([2, 2, 2, 3, 3], shape_tensor)
|
||||||
dense = operator.to_dense().eval(feed_dict=feed_dict)
|
dense = operator.to_dense().eval(feed_dict=feed_dict)
|
||||||
self.assertAllEqual([2, 3, 3], dense.shape)
|
self.assertAllEqual([2, 2, 2, 3, 3], dense.shape)
|
||||||
|
|
||||||
def test_u_and_v_incompatible_batch_shape_raises(self):
|
def test_u_and_v_incompatible_batch_shape_raises(self):
|
||||||
base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
|
base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
|
||||||
|
|||||||
@ -339,12 +339,21 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
|||||||
def _shape(self):
|
def _shape(self):
|
||||||
batch_shape = array_ops.broadcast_static_shape(
|
batch_shape = array_ops.broadcast_static_shape(
|
||||||
self.base_operator.batch_shape,
|
self.base_operator.batch_shape,
|
||||||
|
self.diag_operator.batch_shape)
|
||||||
|
batch_shape = array_ops.broadcast_static_shape(
|
||||||
|
batch_shape,
|
||||||
self.u.shape[:-2])
|
self.u.shape[:-2])
|
||||||
|
batch_shape = array_ops.broadcast_static_shape(
|
||||||
|
batch_shape,
|
||||||
|
self.v.shape[:-2])
|
||||||
return batch_shape.concatenate(self.base_operator.shape[-2:])
|
return batch_shape.concatenate(self.base_operator.shape[-2:])
|
||||||
|
|
||||||
def _shape_tensor(self):
|
def _shape_tensor(self):
|
||||||
batch_shape = array_ops.broadcast_dynamic_shape(
|
batch_shape = array_ops.broadcast_dynamic_shape(
|
||||||
self.base_operator.batch_shape_tensor(),
|
self.base_operator.batch_shape_tensor(),
|
||||||
|
self.diag_operator.batch_shape_tensor())
|
||||||
|
batch_shape = array_ops.broadcast_dynamic_shape(
|
||||||
|
batch_shape,
|
||||||
array_ops.shape(self.u)[:-2])
|
array_ops.shape(self.u)[:-2])
|
||||||
batch_shape = array_ops.broadcast_dynamic_shape(
|
batch_shape = array_ops.broadcast_dynamic_shape(
|
||||||
batch_shape,
|
batch_shape,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user