s/get_shape()/shape in linalg. This is the TF2 way.
PiperOrigin-RevId: 264514950
This commit is contained in:
parent
25aba97589
commit
c087e10678
@ -339,8 +339,8 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
|
||||
h = operator.convolution_kernel()
|
||||
c = operator.to_dense()
|
||||
|
||||
self.assertAllEqual((2, 3), h.get_shape())
|
||||
self.assertAllEqual((2, 3, 3), c.get_shape())
|
||||
self.assertAllEqual((2, 3), h.shape)
|
||||
self.assertAllEqual((2, 3, 3), c.shape)
|
||||
self.assertAllClose(h.eval(), self.evaluate(c)[:, :, 0])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
|
@ -145,16 +145,16 @@ class LinearOperatorDiagTest(
|
||||
# Create a batch matrix with the broadcast shape of operator.
|
||||
diag_broadcast = array_ops.concat((diag, diag), 1)
|
||||
mat = array_ops.matrix_diag(diag_broadcast)
|
||||
self.assertAllEqual((2, 2, 3, 3), mat.get_shape()) # being pedantic.
|
||||
self.assertAllEqual((2, 2, 3, 3), mat.shape) # being pedantic.
|
||||
|
||||
operator_matmul = operator.matmul(x)
|
||||
mat_matmul = math_ops.matmul(mat, x)
|
||||
self.assertAllEqual(operator_matmul.get_shape(), mat_matmul.get_shape())
|
||||
self.assertAllEqual(operator_matmul.shape, mat_matmul.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_matmul, mat_matmul]))
|
||||
|
||||
operator_solve = operator.solve(x)
|
||||
mat_solve = linalg_ops.matrix_solve(mat, x)
|
||||
self.assertAllEqual(operator_solve.get_shape(), mat_solve.get_shape())
|
||||
self.assertAllEqual(operator_solve.shape, mat_solve.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_solve, mat_solve]))
|
||||
|
||||
def test_diag_matmul(self):
|
||||
|
@ -173,7 +173,7 @@ class LinearOperatorIdentityTest(
|
||||
operator_matmul = operator.matmul(x)
|
||||
expected = x
|
||||
|
||||
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
|
||||
self.assertAllEqual(operator_matmul.shape, expected.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
|
||||
|
||||
def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
|
||||
@ -207,7 +207,7 @@ class LinearOperatorIdentityTest(
|
||||
expected = x + zeros
|
||||
|
||||
operator_matmul = operator.matmul(x)
|
||||
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
|
||||
self.assertAllEqual(operator_matmul.shape, expected.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
|
||||
|
||||
def test_broadcast_matmul_dynamic_shapes(self):
|
||||
@ -423,13 +423,13 @@ class LinearOperatorScaledIdentityTest(
|
||||
# Test matmul
|
||||
expected = x * 2.2 + zeros
|
||||
operator_matmul = operator.matmul(x)
|
||||
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
|
||||
self.assertAllEqual(operator_matmul.shape, expected.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
|
||||
|
||||
# Test solve
|
||||
expected = x / 2.2 + zeros
|
||||
operator_solve = operator.solve(x)
|
||||
self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
|
||||
self.assertAllEqual(operator_solve.shape, expected.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_solve, expected]))
|
||||
|
||||
def test_broadcast_matmul_and_solve_scalar_scale_multiplier(self):
|
||||
@ -449,13 +449,13 @@ class LinearOperatorScaledIdentityTest(
|
||||
# Test matmul
|
||||
expected = x * 2.2
|
||||
operator_matmul = operator.matmul(x)
|
||||
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
|
||||
self.assertAllEqual(operator_matmul.shape, expected.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
|
||||
|
||||
# Test solve
|
||||
expected = x / 2.2
|
||||
operator_solve = operator.solve(x)
|
||||
self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
|
||||
self.assertAllEqual(operator_solve.shape, expected.shape)
|
||||
self.assertAllClose(*self.evaluate([operator_solve, expected]))
|
||||
|
||||
def test_is_x_flags(self):
|
||||
|
@ -80,7 +80,7 @@ class LinearOperatorMatmulSolve(linalg.LinearOperator):
|
||||
is_square=is_square)
|
||||
|
||||
def _shape(self):
|
||||
return self._matrix.get_shape()
|
||||
return self._matrix.shape
|
||||
|
||||
def _shape_tensor(self):
|
||||
return array_ops.shape(self._matrix)
|
||||
@ -136,7 +136,7 @@ class LinearOperatorTest(test.TestCase):
|
||||
operator = LinearOperatorMatmulSolve(matrix)
|
||||
with self.cached_session():
|
||||
operator_dense = operator.to_dense()
|
||||
self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
|
||||
self.assertAllEqual((2, 3, 4), operator_dense.shape)
|
||||
self.assertAllClose(matrix, self.evaluate(operator_dense))
|
||||
|
||||
def test_generic_to_dense_method_non_square_matrix_tensor(self):
|
||||
@ -152,7 +152,7 @@ class LinearOperatorTest(test.TestCase):
|
||||
x = [1., 1.]
|
||||
with self.cached_session():
|
||||
y = operator.matvec(x)
|
||||
self.assertAllEqual((2,), y.get_shape())
|
||||
self.assertAllEqual((2,), y.shape)
|
||||
self.assertAllClose([1., 2.], self.evaluate(y))
|
||||
|
||||
def test_solvevec(self):
|
||||
@ -161,7 +161,7 @@ class LinearOperatorTest(test.TestCase):
|
||||
y = [1., 1.]
|
||||
with self.cached_session():
|
||||
x = operator.solvevec(y)
|
||||
self.assertAllEqual((2,), x.get_shape())
|
||||
self.assertAllEqual((2,), x.shape)
|
||||
self.assertAllClose([1., 1 / 2.], self.evaluate(x))
|
||||
|
||||
def test_is_square_set_to_true_for_square_static_shapes(self):
|
||||
|
@ -115,8 +115,8 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
|
||||
|
||||
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
|
||||
|
||||
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
|
||||
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
|
||||
self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
|
||||
self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
|
||||
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
|
||||
self.assertAllClose(x_bc_expected, x_bc_)
|
||||
self.assertAllClose(y_bc_expected, y_bc_)
|
||||
@ -133,8 +133,8 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
|
||||
|
||||
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
|
||||
|
||||
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
|
||||
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
|
||||
self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
|
||||
self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
|
||||
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
|
||||
self.assertAllClose(x_bc_expected, x_bc_)
|
||||
self.assertAllClose(y_bc_expected, y_bc_)
|
||||
@ -197,7 +197,7 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
|
||||
chol_broadcast = chol + np.zeros((2, 1, 1))
|
||||
|
||||
result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
|
||||
self.assertAllEqual((2, 3, 7), result.get_shape())
|
||||
self.assertAllEqual((2, 3, 7), result.shape)
|
||||
expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
|
||||
self.assertAllClose(*self.evaluate([expected, result]))
|
||||
|
||||
@ -227,7 +227,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
|
||||
rhs_broadcast = rhs + np.zeros((2, 1, 1))
|
||||
|
||||
result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
|
||||
self.assertAllEqual((2, 3, 7), result.get_shape())
|
||||
self.assertAllEqual((2, 3, 7), result.shape)
|
||||
expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
|
||||
self.assertAllClose(*self.evaluate([expected, result]))
|
||||
|
||||
@ -244,7 +244,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
|
||||
matrix_broadcast = matrix + np.zeros((2, 1, 1))
|
||||
|
||||
result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
|
||||
self.assertAllEqual((2, 3, 2), result.get_shape())
|
||||
self.assertAllEqual((2, 3, 2), result.shape)
|
||||
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
|
||||
self.assertAllClose(*self.evaluate([expected, result]))
|
||||
|
||||
@ -282,7 +282,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
|
||||
|
||||
result = linear_operator_util.matrix_solve_with_broadcast(
|
||||
matrix, rhs, adjoint=True)
|
||||
self.assertAllEqual((2, 3, 2), result.get_shape())
|
||||
self.assertAllEqual((2, 3, 2), result.shape)
|
||||
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True)
|
||||
self.assertAllClose(*self.evaluate([expected, result]))
|
||||
|
||||
@ -313,7 +313,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
|
||||
|
||||
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
|
||||
matrix, rhs)
|
||||
self.assertAllEqual((2, 3, 7), result.get_shape())
|
||||
self.assertAllEqual((2, 3, 7), result.shape)
|
||||
expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
|
||||
self.assertAllClose(*self.evaluate([expected, result]))
|
||||
|
||||
@ -331,7 +331,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
|
||||
|
||||
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
|
||||
matrix, rhs)
|
||||
self.assertAllEqual((2, 3, 2), result.get_shape())
|
||||
self.assertAllEqual((2, 3, 2), result.shape)
|
||||
expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
|
||||
self.assertAllClose(*self.evaluate([expected, result]))
|
||||
|
||||
@ -349,7 +349,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
|
||||
|
||||
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
|
||||
matrix, rhs, adjoint=True)
|
||||
self.assertAllEqual((2, 3, 2), result.get_shape())
|
||||
self.assertAllEqual((2, 3, 2), result.shape)
|
||||
expected = linalg_ops.matrix_triangular_solve(
|
||||
matrix_broadcast, rhs, adjoint=True)
|
||||
self.assertAllClose(*self.evaluate([expected, result]))
|
||||
|
@ -137,8 +137,8 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
"""Static check of spectrum. Then return `Tensor` version."""
|
||||
spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
|
||||
|
||||
if spectrum.get_shape().ndims is not None:
|
||||
if spectrum.get_shape().ndims < self.block_depth:
|
||||
if spectrum.shape.ndims is not None:
|
||||
if spectrum.shape.ndims < self.block_depth:
|
||||
raise ValueError(
|
||||
"Argument spectrum must have at least %d dimensions. Found: %s" %
|
||||
(self.block_depth, spectrum))
|
||||
@ -183,7 +183,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
|
||||
@property
|
||||
def block_shape(self):
|
||||
return self.spectrum.get_shape()[-self.block_depth:]
|
||||
return self.spectrum.shape[-self.block_depth:]
|
||||
|
||||
@property
|
||||
def spectrum(self):
|
||||
@ -207,11 +207,11 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
|
||||
# Blockify: Blockfy trailing dimensions.
|
||||
# [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
|
||||
if (vec.get_shape().is_fully_defined() and
|
||||
if (vec.shape.is_fully_defined() and
|
||||
self.block_shape.is_fully_defined()):
|
||||
# vec_leading_shape = [m3, m0, m1],
|
||||
# the parts of vec that will not be blockified.
|
||||
vec_leading_shape = vec.get_shape()[:-1]
|
||||
vec_leading_shape = vec.shape[:-1]
|
||||
final_shape = vec_leading_shape.concatenate(self.block_shape)
|
||||
else:
|
||||
vec_leading_shape = array_ops.shape(vec)[:-1]
|
||||
@ -232,9 +232,9 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
|
||||
# Un-blockify: Flatten block dimensions. Reshape
|
||||
# [v0, v1, v2, v3] --> [v0, v1, v2*v3].
|
||||
if vec.get_shape().is_fully_defined():
|
||||
if vec.shape.is_fully_defined():
|
||||
# vec_shape = [v0, v1, v2, v3]
|
||||
vec_shape = vec.get_shape().as_list()
|
||||
vec_shape = vec.shape.as_list()
|
||||
# vec_leading_shape = [v0, v1]
|
||||
vec_leading_shape = vec_shape[:-self.block_depth]
|
||||
# vec_block_shape = [v2, v3]
|
||||
@ -298,7 +298,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
return math_ops.cast(h, self.dtype)
|
||||
|
||||
def _shape(self):
|
||||
s_shape = self._spectrum.get_shape()
|
||||
s_shape = self._spectrum.shape
|
||||
# Suppose spectrum.shape = [a, b, c, d]
|
||||
# block_depth = 2
|
||||
# Then:
|
||||
@ -471,8 +471,8 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
|
||||
# Get shape of diag along with the axis over which to reduce the spectrum.
|
||||
# We will reduce the spectrum over all block indices.
|
||||
if self.spectrum.get_shape().is_fully_defined():
|
||||
spec_rank = self.spectrum.get_shape().ndims
|
||||
if self.spectrum.shape.is_fully_defined():
|
||||
spec_rank = self.spectrum.shape.ndims
|
||||
axis = np.arange(spec_rank - self.block_depth, spec_rank, dtype=np.int32)
|
||||
else:
|
||||
spec_rank = array_ops.rank(self.spectrum)
|
||||
|
@ -167,13 +167,13 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
|
||||
def _check_diag(self, diag):
|
||||
"""Static check of diag."""
|
||||
if diag.get_shape().ndims is not None and diag.get_shape().ndims < 1:
|
||||
if diag.shape.ndims is not None and diag.shape.ndims < 1:
|
||||
raise ValueError("Argument diag must have at least 1 dimension. "
|
||||
"Found: %s" % diag)
|
||||
|
||||
def _shape(self):
|
||||
# If d_shape = [5, 3], we return [5, 3, 3].
|
||||
d_shape = self._diag.get_shape()
|
||||
d_shape = self._diag.shape
|
||||
return d_shape.concatenate(d_shape[-1:])
|
||||
|
||||
def _shape_tensor(self):
|
||||
|
@ -166,13 +166,13 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
||||
"Argument matrix must have dtype in %s. Found: %s"
|
||||
% (allowed_dtypes, dtype))
|
||||
|
||||
if matrix.get_shape().ndims is not None and matrix.get_shape().ndims < 2:
|
||||
if matrix.shape.ndims is not None and matrix.shape.ndims < 2:
|
||||
raise ValueError(
|
||||
"Argument matrix must have at least 2 dimensions. Found: %s"
|
||||
% matrix)
|
||||
|
||||
def _shape(self):
|
||||
return self._matrix.get_shape()
|
||||
return self._matrix.shape
|
||||
|
||||
def _shape_tensor(self):
|
||||
return array_ops.shape(self._matrix)
|
||||
|
@ -155,15 +155,15 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
|
||||
|
||||
def _check_reflection_axis(self, reflection_axis):
|
||||
"""Static check of reflection_axis."""
|
||||
if (reflection_axis.get_shape().ndims is not None and
|
||||
reflection_axis.get_shape().ndims < 1):
|
||||
if (reflection_axis.shape.ndims is not None and
|
||||
reflection_axis.shape.ndims < 1):
|
||||
raise ValueError(
|
||||
"Argument reflection_axis must have at least 1 dimension. "
|
||||
"Found: %s" % reflection_axis)
|
||||
|
||||
def _shape(self):
|
||||
# If d_shape = [5, 3], we return [5, 3, 3].
|
||||
d_shape = self._reflection_axis.get_shape()
|
||||
d_shape = self._reflection_axis.shape
|
||||
return d_shape.concatenate(d_shape[-1:])
|
||||
|
||||
def _shape_tensor(self):
|
||||
|
@ -333,10 +333,10 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
# Also, the final dimension of 'x' can have any shape.
|
||||
# Therefore, the final two dimensions of special_shape are 1's.
|
||||
special_shape = self.batch_shape.concatenate([1, 1])
|
||||
bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
|
||||
bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
|
||||
if special_shape.is_fully_defined():
|
||||
# bshape.is_fully_defined iff special_shape.is_fully_defined.
|
||||
if bshape == x.get_shape():
|
||||
if bshape == x.shape:
|
||||
return x
|
||||
# Use the built in broadcasting of addition.
|
||||
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
|
||||
@ -628,7 +628,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
||||
matrix_shape = tensor_shape.TensorShape((self._num_rows_static,
|
||||
self._num_rows_static))
|
||||
|
||||
batch_shape = self.multiplier.get_shape()
|
||||
batch_shape = self.multiplier.shape
|
||||
return batch_shape.concatenate(matrix_shape)
|
||||
|
||||
def _shape_tensor(self):
|
||||
|
@ -271,7 +271,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
"""Static check that shapes are compatible."""
|
||||
# Broadcast shape also checks that u and v are compatible.
|
||||
uv_shape = array_ops.broadcast_static_shape(
|
||||
self.u.get_shape(), self.v.get_shape())
|
||||
self.u.shape, self.v.shape)
|
||||
|
||||
batch_shape = array_ops.broadcast_static_shape(
|
||||
self.base_operator.batch_shape, uv_shape[:-2])
|
||||
@ -282,9 +282,9 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
|
||||
if self._diag_update is not None:
|
||||
tensor_shape.dimension_at_index(uv_shape, -1).assert_is_compatible_with(
|
||||
self._diag_update.get_shape()[-1])
|
||||
self._diag_update.shape[-1])
|
||||
array_ops.broadcast_static_shape(
|
||||
batch_shape, self._diag_update.get_shape()[:-1])
|
||||
batch_shape, self._diag_update.shape[:-1])
|
||||
|
||||
def _set_diag_operators(self, diag_update, is_diag_update_positive):
|
||||
"""Set attributes self._diag_update and self._diag_operator."""
|
||||
@ -335,7 +335,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
def _shape(self):
|
||||
batch_shape = array_ops.broadcast_static_shape(
|
||||
self.base_operator.batch_shape,
|
||||
self.u.get_shape()[:-2])
|
||||
self.u.shape[:-2])
|
||||
return batch_shape.concatenate(self.base_operator.shape[-2:])
|
||||
|
||||
def _shape_tensor(self):
|
||||
|
@ -160,7 +160,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
|
||||
def _check_tril(self, tril):
|
||||
"""Static check of the `tril` argument."""
|
||||
|
||||
if tril.get_shape().ndims is not None and tril.get_shape().ndims < 2:
|
||||
if tril.shape.ndims is not None and tril.shape.ndims < 2:
|
||||
raise ValueError(
|
||||
"Argument tril must have at least 2 dimensions. Found: %s"
|
||||
% tril)
|
||||
@ -174,7 +174,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
|
||||
return array_ops.matrix_diag_part(self._tril)
|
||||
|
||||
def _shape(self):
|
||||
return self._tril.get_shape()
|
||||
return self._tril.shape
|
||||
|
||||
def _shape_tensor(self):
|
||||
return array_ops.shape(self._tril)
|
||||
|
@ -272,7 +272,7 @@ def _test_to_dense(use_placeholder, shapes_info, dtype):
|
||||
shapes_info, dtype, use_placeholder=use_placeholder)
|
||||
op_dense = operator.to_dense()
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(shapes_info.shape, op_dense.get_shape())
|
||||
self.assertAllEqual(shapes_info.shape, op_dense.shape)
|
||||
op_dense_v, mat_v = sess.run([op_dense, mat])
|
||||
self.assertAC(op_dense_v, mat_v)
|
||||
return test_to_dense
|
||||
@ -286,7 +286,7 @@ def _test_det(use_placeholder, shapes_info, dtype):
|
||||
shapes_info, dtype, use_placeholder=use_placeholder)
|
||||
op_det = operator.determinant()
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(shapes_info.shape[:-2], op_det.get_shape())
|
||||
self.assertAllEqual(shapes_info.shape[:-2], op_det.shape)
|
||||
op_det_v, mat_det_v = sess.run(
|
||||
[op_det, linalg_ops.matrix_determinant(mat)])
|
||||
self.assertAC(op_det_v, mat_det_v)
|
||||
@ -303,7 +303,7 @@ def _test_log_abs_det(use_placeholder, shapes_info, dtype):
|
||||
_, mat_log_abs_det = linalg.slogdet(mat)
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(
|
||||
shapes_info.shape[:-2], op_log_abs_det.get_shape())
|
||||
shapes_info.shape[:-2], op_log_abs_det.shape)
|
||||
op_log_abs_det_v, mat_log_abs_det_v = sess.run(
|
||||
[op_log_abs_det, mat_log_abs_det])
|
||||
self.assertAC(op_log_abs_det_v, mat_log_abs_det_v)
|
||||
@ -340,8 +340,8 @@ def _test_matmul_base(
|
||||
op_matmul = operator.matmul(x, adjoint=adjoint)
|
||||
mat_matmul = math_ops.matmul(mat, x, adjoint_a=adjoint)
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(op_matmul.get_shape(),
|
||||
mat_matmul.get_shape())
|
||||
self.assertAllEqual(op_matmul.shape,
|
||||
mat_matmul.shape)
|
||||
op_matmul_v, mat_matmul_v = sess.run(
|
||||
[op_matmul, mat_matmul])
|
||||
self.assertAC(op_matmul_v, mat_matmul_v)
|
||||
@ -445,8 +445,8 @@ def _test_solve_base(
|
||||
mat_solve = linear_operator_util.matrix_solve_with_broadcast(
|
||||
mat, rhs, adjoint=adjoint)
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(op_solve.get_shape(),
|
||||
mat_solve.get_shape())
|
||||
self.assertAllEqual(op_solve.shape,
|
||||
mat_solve.shape)
|
||||
op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
|
||||
self.assertAC(op_solve_v, mat_solve_v)
|
||||
|
||||
@ -500,7 +500,7 @@ def _test_trace(use_placeholder, shapes_info, dtype):
|
||||
op_trace = operator.trace()
|
||||
mat_trace = math_ops.trace(mat)
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape())
|
||||
self.assertAllEqual(op_trace.shape, mat_trace.shape)
|
||||
op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace])
|
||||
self.assertAC(op_trace_v, mat_trace_v)
|
||||
return test_trace
|
||||
@ -515,7 +515,7 @@ def _test_add_to_tensor(use_placeholder, shapes_info, dtype):
|
||||
op_plus_2mat = operator.add_to_tensor(2 * mat)
|
||||
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(shapes_info.shape, op_plus_2mat.get_shape())
|
||||
self.assertAllEqual(shapes_info.shape, op_plus_2mat.shape)
|
||||
|
||||
op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat])
|
||||
|
||||
@ -533,8 +533,8 @@ def _test_diag_part(use_placeholder, shapes_info, dtype):
|
||||
mat_diag_part = array_ops.matrix_diag_part(mat)
|
||||
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(mat_diag_part.get_shape(),
|
||||
op_diag_part.get_shape())
|
||||
self.assertAllEqual(mat_diag_part.shape,
|
||||
op_diag_part.shape)
|
||||
|
||||
op_diag_part_, mat_diag_part_ = sess.run(
|
||||
[op_diag_part, mat_diag_part])
|
||||
|
@ -168,12 +168,12 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
|
||||
def _check_row_col(self, row, col):
|
||||
"""Static check of row and column."""
|
||||
for name, tensor in [["row", row], ["col", col]]:
|
||||
if tensor.get_shape().ndims is not None and tensor.get_shape().ndims < 1:
|
||||
if tensor.shape.ndims is not None and tensor.shape.ndims < 1:
|
||||
raise ValueError("Argument {} must have at least 1 dimension. "
|
||||
"Found: {}".format(name, tensor))
|
||||
|
||||
if row.get_shape()[-1] is not None and col.get_shape()[-1] is not None:
|
||||
if row.get_shape()[-1] != col.get_shape()[-1]:
|
||||
if row.shape[-1] is not None and col.shape[-1] is not None:
|
||||
if row.shape[-1] != col.shape[-1]:
|
||||
raise ValueError(
|
||||
"Expected square matrix, got row and col with mismatched "
|
||||
"dimensions.")
|
||||
|
@ -239,7 +239,7 @@ def assert_compatible_matrix_dimensions(operator, x):
|
||||
|
||||
def assert_is_batch_matrix(tensor):
|
||||
"""Static assert that `tensor` has rank `2` or higher."""
|
||||
sh = tensor.get_shape()
|
||||
sh = tensor.shape
|
||||
if sh.ndims is not None and sh.ndims < 2:
|
||||
raise ValueError(
|
||||
"Expected [batch] matrix to have at least two dimensions. Found: "
|
||||
@ -327,14 +327,14 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
|
||||
# x.shape = [2, j, k] (batch shape = [2])
|
||||
# y.shape = [3, 1, l, m] (batch shape = [3, 1])
|
||||
# ==> bcast_batch_shape = [3, 2]
|
||||
bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
|
||||
bcast_batch_shape = batch_matrices[0].shape[:-2]
|
||||
for mat in batch_matrices[1:]:
|
||||
bcast_batch_shape = array_ops.broadcast_static_shape(
|
||||
bcast_batch_shape,
|
||||
mat.get_shape()[:-2])
|
||||
mat.shape[:-2])
|
||||
if bcast_batch_shape.is_fully_defined():
|
||||
for i, mat in enumerate(batch_matrices):
|
||||
if mat.get_shape()[:-2] != bcast_batch_shape:
|
||||
if mat.shape[:-2] != bcast_batch_shape:
|
||||
bcast_shape = array_ops.concat(
|
||||
[bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
|
||||
batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
|
||||
|
@ -277,10 +277,10 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
|
||||
# Also, the final dimension of 'x' can have any shape.
|
||||
# Therefore, the final two dimensions of special_shape are 1's.
|
||||
special_shape = self.batch_shape.concatenate([1, 1])
|
||||
bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
|
||||
bshape = array_ops.broadcast_static_shape(x.shape, special_shape)
|
||||
if special_shape.is_fully_defined():
|
||||
# bshape.is_fully_defined iff special_shape.is_fully_defined.
|
||||
if bshape == x.get_shape():
|
||||
if bshape == x.shape:
|
||||
return x
|
||||
# Use the built in broadcasting of addition.
|
||||
zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user