diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py index 2c14d4021db..2ef790f9a26 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py @@ -287,24 +287,34 @@ class LinearOperatorLowRankUpdateBroadcastsShape(test.TestCase): @test_util.run_deprecated_v1 def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self): num_rows_ph = array_ops.placeholder(dtypes.int32) - base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph) u_shape_ph = array_ops.placeholder(dtypes.int32) u = array_ops.ones(shape=u_shape_ph) - operator = linalg.LinearOperatorLowRankUpdate(base_operator, u) + v_shape_ph = array_ops.placeholder(dtypes.int32) + v = array_ops.ones(shape=v_shape_ph) + + diag_shape_ph = array_ops.placeholder(dtypes.int32) + diag_update = array_ops.ones(shape=diag_shape_ph) + + operator = linalg.LinearOperatorLowRankUpdate(base_operator, + u=u, + diag_update=diag_update, + v=v) feed_dict = { num_rows_ph: 3, - u_shape_ph: [2, 3, 2], # batch_shape = [2] + u_shape_ph: [1, 1, 2, 3, 2], # batch_shape = [1, 1, 2] + v_shape_ph: [1, 2, 1, 3, 2], # batch_shape = [1, 2, 1] + diag_shape_ph: [2, 1, 1, 2] # batch_shape = [2, 1, 1] } with self.cached_session(): shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict) - self.assertAllEqual([2, 3, 3], shape_tensor) + self.assertAllEqual([2, 2, 2, 3, 3], shape_tensor) dense = operator.to_dense().eval(feed_dict=feed_dict) - self.assertAllEqual([2, 3, 3], dense.shape) + self.assertAllEqual([2, 2, 2, 3, 3], dense.shape) def test_u_and_v_incompatible_batch_shape_raises(self): base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64) diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index 4157233e904..2e60a10e226 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -339,12 +339,21 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): def _shape(self): batch_shape = array_ops.broadcast_static_shape( self.base_operator.batch_shape, + self.diag_operator.batch_shape) + batch_shape = array_ops.broadcast_static_shape( + batch_shape, self.u.shape[:-2]) + batch_shape = array_ops.broadcast_static_shape( + batch_shape, + self.v.shape[:-2]) return batch_shape.concatenate(self.base_operator.shape[-2:]) def _shape_tensor(self): batch_shape = array_ops.broadcast_dynamic_shape( self.base_operator.batch_shape_tensor(), + self.diag_operator.batch_shape_tensor()) + batch_shape = array_ops.broadcast_dynamic_shape( + batch_shape, array_ops.shape(self.u)[:-2]) batch_shape = array_ops.broadcast_dynamic_shape( batch_shape,