s/get_shape()/shape in linalg. This is the TF2 way.

PiperOrigin-RevId: 264514950
Authored by Ian Langmore on 2019-08-20 18:41:20 -07:00; committed by TensorFlower Gardener
parent 25aba97589
commit c087e10678
16 changed files with 72 additions and 72 deletions
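
The substitution is behavior-preserving: `Tensor.get_shape()` and the `Tensor.shape` property return the same static `TensorShape`; `.shape` is simply the idiomatic TF2 spelling. A minimal sketch of the equivalence (the tensor `x` below is illustrative, not taken from this change):

import tensorflow as tf

x = tf.zeros([2, 3, 4])

# Both spellings return the identical static TensorShape.
assert x.shape == x.get_shape()

# Every access pattern touched by this commit works through `.shape`:
x.shape.ndims               # 3
x.shape.as_list()           # [2, 3, 4]
x.shape[:-2]                # TensorShape([2]) -- the batch dimensions
x.shape.is_fully_defined()  # True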

View File

@@ -339,8 +339,8 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
 h = operator.convolution_kernel()
 c = operator.to_dense()
-self.assertAllEqual((2, 3), h.get_shape())
-self.assertAllEqual((2, 3, 3), c.get_shape())
+self.assertAllEqual((2, 3), h.shape)
+self.assertAllEqual((2, 3, 3), c.shape)
 self.assertAllClose(h.eval(), self.evaluate(c)[:, :, 0])
 @test_util.run_deprecated_v1

View File

@@ -145,16 +145,16 @@ class LinearOperatorDiagTest(
 # Create a batch matrix with the broadcast shape of operator.
 diag_broadcast = array_ops.concat((diag, diag), 1)
 mat = array_ops.matrix_diag(diag_broadcast)
-self.assertAllEqual((2, 2, 3, 3), mat.get_shape())  # being pedantic.
+self.assertAllEqual((2, 2, 3, 3), mat.shape)  # being pedantic.
 operator_matmul = operator.matmul(x)
 mat_matmul = math_ops.matmul(mat, x)
-self.assertAllEqual(operator_matmul.get_shape(), mat_matmul.get_shape())
+self.assertAllEqual(operator_matmul.shape, mat_matmul.shape)
 self.assertAllClose(*self.evaluate([operator_matmul, mat_matmul]))
 operator_solve = operator.solve(x)
 mat_solve = linalg_ops.matrix_solve(mat, x)
-self.assertAllEqual(operator_solve.get_shape(), mat_solve.get_shape())
+self.assertAllEqual(operator_solve.shape, mat_solve.shape)
 self.assertAllClose(*self.evaluate([operator_solve, mat_solve]))
 def test_diag_matmul(self):

View File

@@ -173,7 +173,7 @@ class LinearOperatorIdentityTest(
 operator_matmul = operator.matmul(x)
 expected = x
-self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
+self.assertAllEqual(operator_matmul.shape, expected.shape)
 self.assertAllClose(*self.evaluate([operator_matmul, expected]))
 def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
@@ -207,7 +207,7 @@ class LinearOperatorIdentityTest(
 expected = x + zeros
 operator_matmul = operator.matmul(x)
-self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
+self.assertAllEqual(operator_matmul.shape, expected.shape)
 self.assertAllClose(*self.evaluate([operator_matmul, expected]))
 def test_broadcast_matmul_dynamic_shapes(self):
@@ -423,13 +423,13 @@ class LinearOperatorScaledIdentityTest(
 # Test matmul
 expected = x * 2.2 + zeros
 operator_matmul = operator.matmul(x)
-self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
+self.assertAllEqual(operator_matmul.shape, expected.shape)
 self.assertAllClose(*self.evaluate([operator_matmul, expected]))
 # Test solve
 expected = x / 2.2 + zeros
 operator_solve = operator.solve(x)
-self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
+self.assertAllEqual(operator_solve.shape, expected.shape)
 self.assertAllClose(*self.evaluate([operator_solve, expected]))
 def test_broadcast_matmul_and_solve_scalar_scale_multiplier(self):
@@ -449,13 +449,13 @@ class LinearOperatorScaledIdentityTest(
 # Test matmul
 expected = x * 2.2
 operator_matmul = operator.matmul(x)
-self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
+self.assertAllEqual(operator_matmul.shape, expected.shape)
 self.assertAllClose(*self.evaluate([operator_matmul, expected]))
 # Test solve
 expected = x / 2.2
 operator_solve = operator.solve(x)
-self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
+self.assertAllEqual(operator_solve.shape, expected.shape)
 self.assertAllClose(*self.evaluate([operator_solve, expected]))
 def test_is_x_flags(self):

View File

@@ -80,7 +80,7 @@ class LinearOperatorMatmulSolve(linalg.LinearOperator):
 is_square=is_square)
 def _shape(self):
-return self._matrix.get_shape()
+return self._matrix.shape
 def _shape_tensor(self):
 return array_ops.shape(self._matrix)
@@ -136,7 +136,7 @@ class LinearOperatorTest(test.TestCase):
 operator = LinearOperatorMatmulSolve(matrix)
 with self.cached_session():
 operator_dense = operator.to_dense()
-self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
+self.assertAllEqual((2, 3, 4), operator_dense.shape)
 self.assertAllClose(matrix, self.evaluate(operator_dense))
 def test_generic_to_dense_method_non_square_matrix_tensor(self):
@@ -152,7 +152,7 @@ class LinearOperatorTest(test.TestCase):
 x = [1., 1.]
 with self.cached_session():
 y = operator.matvec(x)
-self.assertAllEqual((2,), y.get_shape())
+self.assertAllEqual((2,), y.shape)
 self.assertAllClose([1., 2.], self.evaluate(y))
 def test_solvevec(self):
@@ -161,7 +161,7 @@ class LinearOperatorTest(test.TestCase):
 y = [1., 1.]
 with self.cached_session():
 x = operator.solvevec(y)
-self.assertAllEqual((2,), x.get_shape())
+self.assertAllEqual((2,), x.shape)
 self.assertAllClose([1., 1 / 2.], self.evaluate(x))
 def test_is_square_set_to_true_for_square_static_shapes(self):

View File

@@ -115,8 +115,8 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
 x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
-self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
-self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
+self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
+self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
 x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
 self.assertAllClose(x_bc_expected, x_bc_)
 self.assertAllClose(y_bc_expected, y_bc_)
@@ -133,8 +133,8 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
 x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
-self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
-self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
+self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
+self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
 x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
 self.assertAllClose(x_bc_expected, x_bc_)
 self.assertAllClose(y_bc_expected, y_bc_)
@@ -197,7 +197,7 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
 chol_broadcast = chol + np.zeros((2, 1, 1))
 result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
-self.assertAllEqual((2, 3, 7), result.get_shape())
+self.assertAllEqual((2, 3, 7), result.shape)
 expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
 self.assertAllClose(*self.evaluate([expected, result]))
@@ -227,7 +227,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
 rhs_broadcast = rhs + np.zeros((2, 1, 1))
 result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
-self.assertAllEqual((2, 3, 7), result.get_shape())
+self.assertAllEqual((2, 3, 7), result.shape)
 expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
 self.assertAllClose(*self.evaluate([expected, result]))
@@ -244,7 +244,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
 matrix_broadcast = matrix + np.zeros((2, 1, 1))
 result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
-self.assertAllEqual((2, 3, 2), result.get_shape())
+self.assertAllEqual((2, 3, 2), result.shape)
 expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
 self.assertAllClose(*self.evaluate([expected, result]))
@@ -282,7 +282,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
 result = linear_operator_util.matrix_solve_with_broadcast(
 matrix, rhs, adjoint=True)
-self.assertAllEqual((2, 3, 2), result.get_shape())
+self.assertAllEqual((2, 3, 2), result.shape)
 expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True)
 self.assertAllClose(*self.evaluate([expected, result]))
@@ -313,7 +313,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
 result = linear_operator_util.matrix_triangular_solve_with_broadcast(
 matrix, rhs)
-self.assertAllEqual((2, 3, 7), result.get_shape())
+self.assertAllEqual((2, 3, 7), result.shape)
 expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
 self.assertAllClose(*self.evaluate([expected, result]))
@@ -331,7 +331,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
 result = linear_operator_util.matrix_triangular_solve_with_broadcast(
 matrix, rhs)
-self.assertAllEqual((2, 3, 2), result.get_shape())
+self.assertAllEqual((2, 3, 2), result.shape)
 expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
 self.assertAllClose(*self.evaluate([expected, result]))
@@ -349,7 +349,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
 result = linear_operator_util.matrix_triangular_solve_with_broadcast(
 matrix, rhs, adjoint=True)
-self.assertAllEqual((2, 3, 2), result.get_shape())
+self.assertAllEqual((2, 3, 2), result.shape)
 expected = linalg_ops.matrix_triangular_solve(
 matrix_broadcast, rhs, adjoint=True)
 self.assertAllClose(*self.evaluate([expected, result]))

View File

@@ -137,8 +137,8 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
 """Static check of spectrum. Then return `Tensor` version."""
 spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
-if spectrum.get_shape().ndims is not None:
-if spectrum.get_shape().ndims < self.block_depth:
+if spectrum.shape.ndims is not None:
+if spectrum.shape.ndims < self.block_depth:
 raise ValueError(
 "Argument spectrum must have at least %d dimensions. Found: %s" %
 (self.block_depth, spectrum))
@@ -183,7 +183,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
 @property
 def block_shape(self):
-return self.spectrum.get_shape()[-self.block_depth:]
+return self.spectrum.shape[-self.block_depth:]
 @property
 def spectrum(self):
@@ -207,11 +207,11 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
 # Blockify: Blockify trailing dimensions.
 # [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
-if (vec.get_shape().is_fully_defined() and
+if (vec.shape.is_fully_defined() and
 self.block_shape.is_fully_defined()):
 # vec_leading_shape = [m3, m0, m1],
 # the parts of vec that will not be blockified.
-vec_leading_shape = vec.get_shape()[:-1]
+vec_leading_shape = vec.shape[:-1]
 final_shape = vec_leading_shape.concatenate(self.block_shape)
 else:
 vec_leading_shape = array_ops.shape(vec)[:-1]
@@ -232,9 +232,9 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
 # Un-blockify: Flatten block dimensions. Reshape
 # [v0, v1, v2, v3] --> [v0, v1, v2*v3].
-if vec.get_shape().is_fully_defined():
+if vec.shape.is_fully_defined():
 # vec_shape = [v0, v1, v2, v3]
-vec_shape = vec.get_shape().as_list()
+vec_shape = vec.shape.as_list()
 # vec_leading_shape = [v0, v1]
 vec_leading_shape = vec_shape[:-self.block_depth]
 # vec_block_shape = [v2, v3]
@@ -298,7 +298,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
 return math_ops.cast(h, self.dtype)
 def _shape(self):
-s_shape = self._spectrum.get_shape()
+s_shape = self._spectrum.shape
 # Suppose spectrum.shape = [a, b, c, d]
 # block_depth = 2
 # Then:
@@ -471,8 +471,8 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
 # Get shape of diag along with the axis over which to reduce the spectrum.
 # We will reduce the spectrum over all block indices.
-if self.spectrum.get_shape().is_fully_defined():
-spec_rank = self.spectrum.get_shape().ndims
+if self.spectrum.shape.is_fully_defined():
+spec_rank = self.spectrum.shape.ndims
 axis = np.arange(spec_rank - self.block_depth, spec_rank, dtype=np.int32)
 else:
 spec_rank = array_ops.rank(self.spectrum)

View File

@@ -167,13 +167,13 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
 def _check_diag(self, diag):
 """Static check of diag."""
-if diag.get_shape().ndims is not None and diag.get_shape().ndims < 1:
+if diag.shape.ndims is not None and diag.shape.ndims < 1:
 raise ValueError("Argument diag must have at least 1 dimension. "
 "Found: %s" % diag)
 def _shape(self):
 # If d_shape = [5, 3], we return [5, 3, 3].
-d_shape = self._diag.get_shape()
+d_shape = self._diag.shape
 return d_shape.concatenate(d_shape[-1:])
 def _shape_tensor(self):
def _shape_tensor(self):

View File

@@ -166,13 +166,13 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
 "Argument matrix must have dtype in %s. Found: %s"
 % (allowed_dtypes, dtype))
-if matrix.get_shape().ndims is not None and matrix.get_shape().ndims < 2:
+if matrix.shape.ndims is not None and matrix.shape.ndims < 2:
 raise ValueError(
 "Argument matrix must have at least 2 dimensions. Found: %s"
 % matrix)
 def _shape(self):
-return self._matrix.get_shape()
+return self._matrix.shape
 def _shape_tensor(self):
 return array_ops.shape(self._matrix)

View File

@@ -155,15 +155,15 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
 def _check_reflection_axis(self, reflection_axis):
 """Static check of reflection_axis."""
-if (reflection_axis.get_shape().ndims is not None and
-reflection_axis.get_shape().ndims < 1):
+if (reflection_axis.shape.ndims is not None and
+reflection_axis.shape.ndims < 1):
 raise ValueError(
 "Argument reflection_axis must have at least 1 dimension. "
 "Found: %s" % reflection_axis)
 def _shape(self):
 # If d_shape = [5, 3], we return [5, 3, 3].
-d_shape = self._reflection_axis.get_shape()
+d_shape = self._reflection_axis.shape
 return d_shape.concatenate(d_shape[-1:])
 def _shape_tensor(self):

View File

@@ -333,10 +333,10 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
 # Also, the final dimension of 'x' can have any shape.
 # Therefore, the final two dimensions of special_shape are 1's.
 special_shape = self.batch_shape.concatenate([1, 1])
-bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
+bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
 if special_shape.is_fully_defined():
 # bshape.is_fully_defined iff special_shape.is_fully_defined.
-if bshape == x.get_shape():
+if bshape == x.shape:
 return x
 # Use the built in broadcasting of addition.
 zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
@@ -628,7 +628,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
 matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
 self._num_rows_static))
-batch_shape = self.multiplier.get_shape()
+batch_shape = self.multiplier.shape
 return batch_shape.concatenate(matrix_shape)
 def _shape_tensor(self):

View File

@@ -271,7 +271,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
 """Static check that shapes are compatible."""
 # Broadcast shape also checks that u and v are compatible.
 uv_shape = array_ops.broadcast_static_shape(
-self.u.get_shape(), self.v.get_shape())
+self.u.shape, self.v.shape)
 batch_shape = array_ops.broadcast_static_shape(
 self.base_operator.batch_shape, uv_shape[:-2])
@@ -282,9 +282,9 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
 if self._diag_update is not None:
 tensor_shape.dimension_at_index(uv_shape, -1).assert_is_compatible_with(
-self._diag_update.get_shape()[-1])
+self._diag_update.shape[-1])
 array_ops.broadcast_static_shape(
-batch_shape, self._diag_update.get_shape()[:-1])
+batch_shape, self._diag_update.shape[:-1])
 def _set_diag_operators(self, diag_update, is_diag_update_positive):
 """Set attributes self._diag_update and self._diag_operator."""
@@ -335,7 +335,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
 def _shape(self):
 batch_shape = array_ops.broadcast_static_shape(
 self.base_operator.batch_shape,
-self.u.get_shape()[:-2])
+self.u.shape[:-2])
 return batch_shape.concatenate(self.base_operator.shape[-2:])
 def _shape_tensor(self):

View File

@@ -160,7 +160,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
 def _check_tril(self, tril):
 """Static check of the `tril` argument."""
-if tril.get_shape().ndims is not None and tril.get_shape().ndims < 2:
+if tril.shape.ndims is not None and tril.shape.ndims < 2:
 raise ValueError(
 "Argument tril must have at least 2 dimensions. Found: %s"
 % tril)
@@ -174,7 +174,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
 return array_ops.matrix_diag_part(self._tril)
 def _shape(self):
-return self._tril.get_shape()
+return self._tril.shape
 def _shape_tensor(self):
 return array_ops.shape(self._tril)

View File

@@ -272,7 +272,7 @@ def _test_to_dense(use_placeholder, shapes_info, dtype):
 shapes_info, dtype, use_placeholder=use_placeholder)
 op_dense = operator.to_dense()
 if not use_placeholder:
-self.assertAllEqual(shapes_info.shape, op_dense.get_shape())
+self.assertAllEqual(shapes_info.shape, op_dense.shape)
 op_dense_v, mat_v = sess.run([op_dense, mat])
 self.assertAC(op_dense_v, mat_v)
 return test_to_dense
@@ -286,7 +286,7 @@ def _test_det(use_placeholder, shapes_info, dtype):
 shapes_info, dtype, use_placeholder=use_placeholder)
 op_det = operator.determinant()
 if not use_placeholder:
-self.assertAllEqual(shapes_info.shape[:-2], op_det.get_shape())
+self.assertAllEqual(shapes_info.shape[:-2], op_det.shape)
 op_det_v, mat_det_v = sess.run(
 [op_det, linalg_ops.matrix_determinant(mat)])
 self.assertAC(op_det_v, mat_det_v)
@@ -303,7 +303,7 @@ def _test_log_abs_det(use_placeholder, shapes_info, dtype):
 _, mat_log_abs_det = linalg.slogdet(mat)
 if not use_placeholder:
 self.assertAllEqual(
-shapes_info.shape[:-2], op_log_abs_det.get_shape())
+shapes_info.shape[:-2], op_log_abs_det.shape)
 op_log_abs_det_v, mat_log_abs_det_v = sess.run(
 [op_log_abs_det, mat_log_abs_det])
 self.assertAC(op_log_abs_det_v, mat_log_abs_det_v)
@@ -340,8 +340,8 @@ def _test_matmul_base(
 op_matmul = operator.matmul(x, adjoint=adjoint)
 mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
 if not use_placeholder:
-self.assertAllEqual(op_matmul.get_shape(),
-mat_matmul.get_shape())
+self.assertAllEqual(op_matmul.shape,
+mat_matmul.shape)
 op_matmul_v, mat_matmul_v = sess.run(
 [op_matmul, mat_matmul])
 self.assertAC(op_matmul_v, mat_matmul_v)
@@ -445,8 +445,8 @@ def _test_solve_base(
 mat_solve = linear_operator_util.matrix_solve_with_broadcast(
 mat, rhs, adjoint=adjoint)
 if not use_placeholder:
-self.assertAllEqual(op_solve.get_shape(),
-mat_solve.get_shape())
+self.assertAllEqual(op_solve.shape,
+mat_solve.shape)
 op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
 self.assertAC(op_solve_v, mat_solve_v)
@@ -500,7 +500,7 @@ def _test_trace(use_placeholder, shapes_info, dtype):
 op_trace = operator.trace()
 mat_trace = math_ops.trace(mat)
 if not use_placeholder:
-self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape())
+self.assertAllEqual(op_trace.shape, mat_trace.shape)
 op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace])
 self.assertAC(op_trace_v, mat_trace_v)
 return test_trace
@@ -515,7 +515,7 @@ def _test_add_to_tensor(use_placeholder, shapes_info, dtype):
 op_plus_2mat = operator.add_to_tensor(2 * mat)
 if not use_placeholder:
-self.assertAllEqual(shapes_info.shape, op_plus_2mat.get_shape())
+self.assertAllEqual(shapes_info.shape, op_plus_2mat.shape)
 op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat])
@@ -533,8 +533,8 @@ def _test_diag_part(use_placeholder, shapes_info, dtype):
 mat_diag_part = array_ops.matrix_diag_part(mat)
 if not use_placeholder:
-self.assertAllEqual(mat_diag_part.get_shape(),
-op_diag_part.get_shape())
+self.assertAllEqual(mat_diag_part.shape,
+op_diag_part.shape)
 op_diag_part_, mat_diag_part_ = sess.run(
 [op_diag_part, mat_diag_part])

View File

@@ -168,12 +168,12 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
 def _check_row_col(self, row, col):
 """Static check of row and column."""
 for name, tensor in [["row", row], ["col", col]]:
-if tensor.get_shape().ndims is not None and tensor.get_shape().ndims < 1:
+if tensor.shape.ndims is not None and tensor.shape.ndims < 1:
 raise ValueError("Argument {} must have at least 1 dimension. "
 "Found: {}".format(name, tensor))
-if row.get_shape()[-1] is not None and col.get_shape()[-1] is not None:
-if row.get_shape()[-1] != col.get_shape()[-1]:
+if row.shape[-1] is not None and col.shape[-1] is not None:
+if row.shape[-1] != col.shape[-1]:
 raise ValueError(
 "Expected square matrix, got row and col with mismatched "
 "dimensions.")

View File

@@ -239,7 +239,7 @@ def assert_compatible_matrix_dimensions(operator, x):
 def assert_is_batch_matrix(tensor):
 """Static assert that `tensor` has rank `2` or higher."""
-sh = tensor.get_shape()
+sh = tensor.shape
 if sh.ndims is not None and sh.ndims < 2:
 raise ValueError(
 "Expected [batch] matrix to have at least two dimensions. Found: "
@@ -327,14 +327,14 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
 # x.shape = [2, j, k] (batch shape = [2])
 # y.shape = [3, 1, l, m] (batch shape = [3, 1])
 # ==> bcast_batch_shape = [3, 2]
-bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
+bcast_batch_shape = batch_matrices[0].shape[:-2]
 for mat in batch_matrices[1:]:
 bcast_batch_shape = array_ops.broadcast_static_shape(
 bcast_batch_shape,
-mat.get_shape()[:-2])
+mat.shape[:-2])
 if bcast_batch_shape.is_fully_defined():
 for i, mat in enumerate(batch_matrices):
-if mat.get_shape()[:-2] != bcast_batch_shape:
+if mat.shape[:-2] != bcast_batch_shape:
 bcast_shape = array_ops.concat(
 [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
 batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)

View File

@@ -277,10 +277,10 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
 # Also, the final dimension of 'x' can have any shape.
 # Therefore, the final two dimensions of special_shape are 1's.
 special_shape = self.batch_shape.concatenate([1, 1])
-bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
+bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
 if special_shape.is_fully_defined():
 # bshape.is_fully_defined iff special_shape.is_fully_defined.
-if bshape == x.get_shape():
+if bshape == x.shape:
 return x
 # Use the built in broadcasting of addition.
 zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
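
Note the static/dynamic split these files all maintain: `.shape` (like the old `get_shape()`) is the static `TensorShape` known at graph-construction time and may be only partially defined, which is why every `_shape` above is paired with a `_shape_tensor` that calls `array_ops.shape`. A minimal sketch of the distinction, assuming TF 2.x (the function `f` is illustrative, not part of this commit):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 3])])
def f(x):
  static = x.shape                # TensorShape([None, 3]); leading dim unknown
  dynamic = tf.shape(x)           # int32 Tensor, known only at run time
  if static.is_fully_defined():   # False here, so fall through to dynamic
    return tf.zeros(static)
  return tf.zeros(dynamic)

print(f(tf.ones([5, 3])).shape)   # (5, 3)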