Deprecate tf.batch_matmul and replace it with equivalent calls to tf.matmul, which now supports adjoint and batch matmul.
CL created by: replace_string \ batch_matmul\\\( \ matmul\( plus some manual edits, mostly s/adj_x/adjoint_a/ and s/adj_y/adjoint_b/. Change: 139536034
parent 9e650f5a20
commit 2eabf986b0
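The migration pattern behind most of the edits below, as a rough sketch (not taken from the diff itself; the shapes and variable names are illustrative only): a deprecated tf.batch_matmul call maps onto tf.matmul, which now accepts batched inputs and adjoint_a/adjoint_b flags.

import numpy as np
import tensorflow as tf

# Illustrative batch of matrices (shapes made up for this sketch).
a = tf.constant(np.random.rand(2, 3, 4))  # [batch, 3, 4]
b = tf.constant(np.random.rand(2, 5, 4))  # [batch, 5, 4]

# Before this CL: a dedicated batch op with adj_x/adj_y flags.
#   c = tf.batch_matmul(a, b, adj_y=True)
# After this CL: tf.matmul handles rank-3 (batched) inputs directly,
# and the flags are spelled adjoint_a/adjoint_b.
c = tf.matmul(a, b, adjoint_b=True)  # [batch, 3, 5]

with tf.Session() as sess:
  print(sess.run(c).shape)  # (2, 3, 5)

For real-valued tensors the adjoint is just the transpose, which is presumably why some transpose_a/transpose_b call sites in this diff could be switched to adjoint_a/adjoint_b unchanged.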
@@ -193,7 +193,7 @@ class MultivariateNormalCholeskyTest(tf.test.TestCase):
 mat = self._rng.rand(*shape)
 chol = distributions.matrix_diag_transform(mat, transform=tf.nn.softplus)
 chol = tf.matrix_band_part(chol, -1, 0)
-sigma = tf.batch_matmul(chol, chol, adj_y=True)
+sigma = tf.matmul(chol, chol, adjoint_b=True)
 return chol.eval(), sigma.eval()

 def testNonmatchingMuSigmaFailsStatic(self):
@@ -391,7 +391,7 @@ class MultivariateNormalFullTest(tf.test.TestCase):
 # This ensures sigma is positive def.
 mat_shape = batch_shape + event_shape + event_shape
 mat = self._rng.randn(*mat_shape)
-sigma = tf.batch_matmul(mat, mat, adj_y=True).eval()
+sigma = tf.matmul(mat, mat, adjoint_b=True).eval()

 mu_shape = batch_shape + event_shape
 mu = self._rng.randn(*mu_shape)
@@ -84,7 +84,7 @@ class OperatorPDCholeskyTest(tf.test.TestCase):
 operator = operator_pd_cholesky.OperatorPDCholesky(chol)

 sqrt_operator_times_x = operator.sqrt_matmul(x)
-expected = tf.batch_matmul(chol, x)
+expected = tf.matmul(chol, x)

 self.assertEqual(expected.get_shape(),
 sqrt_operator_times_x.get_shape())
@@ -102,7 +102,7 @@ class OperatorPDCholeskyTest(tf.test.TestCase):
 operator = operator_pd_cholesky.OperatorPDCholesky(chol)

 sqrt_operator_times_x = operator.sqrt_matmul(x)
-expected = tf.batch_matmul(chol, x)
+expected = tf.matmul(chol, x)

 self.assertEqual(expected.get_shape(),
 sqrt_operator_times_x.get_shape())
@@ -121,7 +121,7 @@ class OperatorPDCholeskyTest(tf.test.TestCase):

 sqrt_operator_times_x = operator.sqrt_matmul(x, transpose_x=True)
 # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
-expected = tf.batch_matmul(chol, x, adj_y=True)
+expected = tf.matmul(chol, x, adjoint_b=True)

 self.assertEqual(expected.get_shape(),
 sqrt_operator_times_x.get_shape())
@@ -135,11 +135,11 @@ class OperatorPDCholeskyTest(tf.test.TestCase):
 x = self._rng.rand(*x_shape)
 chol_shape = batch_shape + (k, k)
 chol = self._random_cholesky_array(chol_shape)
-matrix = tf.batch_matmul(chol, chol, adj_y=True)
+matrix = tf.matmul(chol, chol, adjoint_b=True)

 operator = operator_pd_cholesky.OperatorPDCholesky(chol)

-expected = tf.batch_matmul(matrix, x)
+expected = tf.matmul(matrix, x)

 self.assertEqual(expected.get_shape(), operator.matmul(x).get_shape())
 self.assertAllClose(expected.eval(), operator.matmul(x).eval())
@@ -152,11 +152,11 @@ class OperatorPDCholeskyTest(tf.test.TestCase):
 x = self._rng.rand(*x_shape)
 chol_shape = batch_shape + (k, k)
 chol = self._random_cholesky_array(chol_shape)
-matrix = tf.batch_matmul(chol, chol, adj_y=True)
+matrix = tf.matmul(chol, chol, adjoint_b=True)

 operator = operator_pd_cholesky.OperatorPDCholesky(chol)

-expected = tf.batch_matmul(matrix, x)
+expected = tf.matmul(matrix, x)

 self.assertEqual(expected.get_shape(), operator.matmul(x).get_shape())
 self.assertAllClose(expected.eval(), operator.matmul(x).eval())
@@ -169,13 +169,13 @@ class OperatorPDCholeskyTest(tf.test.TestCase):
 x = self._rng.rand(*x_shape)
 chol_shape = batch_shape + (k, k)
 chol = self._random_cholesky_array(chol_shape)
-matrix = tf.batch_matmul(chol, chol, adj_y=True)
+matrix = tf.matmul(chol, chol, adjoint_b=True)

 operator = operator_pd_cholesky.OperatorPDCholesky(chol)
 operator_times_x = operator.matmul(x, transpose_x=True)

 # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
-expected = tf.batch_matmul(matrix, x, adj_y=True)
+expected = tf.matmul(matrix, x, adjoint_b=True)

 self.assertEqual(expected.get_shape(), operator_times_x.get_shape())
 self.assertAllClose(expected.eval(), operator_times_x.eval())
@@ -32,7 +32,7 @@ class OperatorPDFullTest(tf.test.TestCase):

 def _random_positive_def_array(self, *shape):
 matrix = self._rng.rand(*shape)
-return tf.batch_matmul(matrix, matrix, adj_y=True).eval()
+return tf.matmul(matrix, matrix, adjoint_b=True).eval()

 def testPositiveDefiniteMatrixDoesntRaise(self):
 with self.test_session():
@@ -75,7 +75,7 @@ class OperatorSolve(OperatorShape):
 """Operator implements .solve."""

 def __init__(self, chol):
-self._pos_def_matrix = tf.batch_matmul(chol, chol, adj_y=True)
+self._pos_def_matrix = tf.matmul(chol, chol, adjoint_b=True)
 super(OperatorSolve, self).__init__(chol.shape)

 def _solve(self, rhs):
@@ -38,7 +38,7 @@ class OperatorPDSqrtVDVTUpdateTest(
 def _random_pd_matrix(self, shape):
 # With probability 1 this is positive definite.
 sqrt = self._rng.randn(*shape)
-mat = tf.batch_matmul(sqrt, sqrt, adj_y=True)
+mat = tf.matmul(sqrt, sqrt, adjoint_b=True)
 return mat.eval()

 def _random_v_and_diag(self, mat_shape, v_matrix_rank):
@@ -67,11 +67,11 @@ class OperatorPDSqrtVDVTUpdateTest(
 diag_vt = tf.matrix_transpose(v)
 else:
 diag_mat = tf.matrix_diag(diag)
-diag_vt = tf.batch_matmul(diag_mat, v, adj_y=True)
+diag_vt = tf.matmul(diag_mat, v, adjoint_b=True)

-v_diag_vt = tf.batch_matmul(v, diag_vt)
+v_diag_vt = tf.matmul(v, diag_vt)
 sqrt = mat + v_diag_vt
-a = tf.batch_matmul(sqrt, sqrt, adj_y=True)
+a = tf.matmul(sqrt, sqrt, adjoint_b=True)
 return a.eval()

 def _build_operator_and_mat(self, batch_shape, k, dtype=np.float64):
@@ -1394,7 +1394,7 @@ class ScaleAndShift(Bijector):

 def _forward(self, x):
 x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
-x = math_ops.batch_matmul(self.scale, x)
+x = math_ops.matmul(self.scale, x)
 x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
 x += self.shift
 return x
@@ -1776,7 +1776,7 @@ class CholeskyOuterProduct(Bijector):
 x = control_flow_ops.with_dependencies([is_matrix, is_square], x)
 # For safety, explicitly zero-out the upper triangular part.
 x = array_ops.matrix_band_part(x, -1, 0)
-return math_ops.batch_matmul(x, x, adj_y=True)
+return math_ops.matmul(x, x, adjoint_b=True)

 def _inverse_and_inverse_log_det_jacobian(self, y):
 x = (math_ops.sqrt(y) if self._static_event_ndims == 0
@@ -1855,8 +1855,7 @@ class CholeskyOuterProduct(Bijector):
 dim=1)

 sum_weighted_log_diag = array_ops.squeeze(
-math_ops.batch_matmul(math_ops.log(diag), exponents),
-squeeze_dims=-1)
+math_ops.matmul(math_ops.log(diag), exponents), squeeze_dims=-1)
 fldj = p * math.log(2.) + sum_weighted_log_diag

 if x.get_shape().ndims is not None:
@@ -216,9 +216,11 @@ class Dirichlet(distribution.Distribution):
 def _variance(self):
 scale = self.alpha_sum * math_ops.sqrt(1. + self.alpha_sum)
 alpha = self.alpha / scale
-outer_prod = -math_ops.batch_matmul(
-array_ops.expand_dims(alpha, dim=-1),  # column
-array_ops.expand_dims(alpha, dim=-2))  # row
+outer_prod = -math_ops.matmul(
+array_ops.expand_dims(
+alpha, dim=-1),  # column
+array_ops.expand_dims(
+alpha, dim=-2))  # row
 return array_ops.matrix_set_diag(outer_prod,
 alpha * (self.alpha_sum / scale - alpha))

@@ -245,7 +245,7 @@ class DirichletMultinomial(distribution.Distribution):
 def _variance(self):
 alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
 normalized_alpha = self.alpha / alpha_sum
-variance = -math_ops.batch_matmul(
+variance = -math_ops.matmul(
 array_ops.expand_dims(normalized_alpha, -1),
 array_ops.expand_dims(normalized_alpha, -2))
 variance = array_ops.matrix_set_diag(variance, normalized_alpha *
@@ -223,9 +223,8 @@ class Multinomial(distribution.Distribution):

 def _variance(self):
 p = self.p * array_ops.expand_dims(array_ops.ones_like(self.n), -1)
-outer_prod = math_ops.batch_matmul(
-array_ops.expand_dims(self._mean_val, -1),
-array_ops.expand_dims(p, -2))
+outer_prod = math_ops.matmul(
+array_ops.expand_dims(self._mean_val, -1), array_ops.expand_dims(p, -2))
 return array_ops.matrix_set_diag(-outer_prod,
 self._mean_val - self._mean_val * p)

@@ -437,7 +437,7 @@ class OperatorPDBase(object):
 name: A name to give this `Op`.

 Returns:
-A result equivalent to `tf.batch_matmul(self.to_dense(), x)`.
+A result equivalent to `tf.matmul(self.to_dense(), x)`.
 """
 with ops.name_scope(self.name):
 with ops.name_scope(name, values=[x] + self.inputs):
@@ -471,7 +471,7 @@ class OperatorPDBase(object):
 name: A name scope to use for ops added by this method.

 Returns:
-A result equivalent to `tf.batch_matmul(self.sqrt_to_dense(), x)`.
+A result equivalent to `tf.matmul(self.sqrt_to_dense(), x)`.
 """
 with ops.name_scope(self.name):
 with ops.name_scope(name, values=[x] + self.inputs):
@@ -127,21 +127,21 @@ class OperatorPDCholesky(operator_pd.OperatorPDBase):
 return math_ops.matmul(chol, chol_times_x)

 def _batch_matmul(self, x, transpose_x=False):
-# tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
+# tf.matmul is defined x * y, so "y" is on the right, not "x".
 chol = array_ops.matrix_band_part(self._chol, -1, 0)
-chol_times_x = math_ops.batch_matmul(
-chol, x, adj_x=True, adj_y=transpose_x)
-return math_ops.batch_matmul(chol, chol_times_x)
+chol_times_x = math_ops.matmul(
+chol, x, adjoint_a=True, adjoint_b=transpose_x)
+return math_ops.matmul(chol, chol_times_x)

 def _sqrt_matmul(self, x, transpose_x=False):
 chol = array_ops.matrix_band_part(self._chol, -1, 0)
 # tf.matmul is defined a * b
-return math_ops.matmul(chol, x, transpose_b=transpose_x)
+return math_ops.matmul(chol, x, adjoint_b=transpose_x)

 def _batch_sqrt_matmul(self, x, transpose_x=False):
 chol = array_ops.matrix_band_part(self._chol, -1, 0)
 # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
-return math_ops.batch_matmul(chol, x, adj_y=transpose_x)
+return math_ops.matmul(chol, x, adjoint_b=transpose_x)

 def _batch_solve(self, rhs):
 return linalg_ops.cholesky_solve(self._chol, rhs)
@@ -181,4 +181,4 @@ class OperatorPDCholesky(operator_pd.OperatorPDBase):

 def _to_dense(self):
 chol = array_ops.matrix_band_part(self._chol, -1, 0)
-return math_ops.batch_matmul(chol, chol, adj_y=True)
+return math_ops.matmul(chol, chol, adjoint_b=True)
@@ -319,10 +319,7 @@ class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):
 # M^{-1} V
 minv_v = self._operator.solve(self._v)
 # V^T M^{-1} V
-if batch_mode:
-vt_minv_v = math_ops.batch_matmul(self._v, minv_v, adj_x=True)
-else:
-vt_minv_v = math_ops.matmul(self._v, minv_v, transpose_a=True)
+vt_minv_v = math_ops.matmul(self._v, minv_v, adjoint_a=True)

 # D^{-1} + V^T M^{-1} V
 capacitance = self._diag_inv_operator.add_to_tensor(vt_minv_v)
@@ -360,14 +357,14 @@ class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):
 v = self._v
 m = self._operator
 d = self._diag_operator
-# The operators call the appropriate matmul/batch_matmul automatically. We
-# cannot override.
-# batch_matmul is defined as: x * y, so adj_x and adj_y are the ways to
-# transpose the left and right.
+# The operators call the appropriate matmul/batch_matmul automatically.
+# We cannot override.
+# batch_matmul is defined as: x * y, so adjoint_a and adjoint_b are the
+# ways to transpose the left and right.
 mx = m.matmul(x, transpose_x=transpose_x)
-vt_x = math_ops.batch_matmul(v, x, adj_x=True, adj_y=transpose_x)
+vt_x = math_ops.matmul(v, x, adjoint_a=True, adjoint_b=transpose_x)
 d_vt_x = d.matmul(vt_x)
-v_d_vt_x = math_ops.batch_matmul(v, d_vt_x)
+v_d_vt_x = math_ops.matmul(v, d_vt_x)

 return mx + v_d_vt_x

@@ -444,11 +441,11 @@ class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):
 # M^{-1} rhs
 minv_rhs = m.solve(rhs)
 # V^T M^{-1} rhs
-vt_minv_rhs = math_ops.batch_matmul(v, minv_rhs, adj_x=True)
+vt_minv_rhs = math_ops.matmul(v, minv_rhs, adjoint_a=True)
 # C^{-1} V^T M^{-1} rhs
 cinv_vt_minv_rhs = linalg_ops.cholesky_solve(cchol, vt_minv_rhs)
 # V C^{-1} V^T M^{-1} rhs
-v_cinv_vt_minv_rhs = math_ops.batch_matmul(v, cinv_vt_minv_rhs)
+v_cinv_vt_minv_rhs = math_ops.matmul(v, cinv_vt_minv_rhs)
 # M^{-1} V C^{-1} V^T M^{-1} rhs
 minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)

@@ -457,7 +454,7 @@ class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):

 def _to_dense(self):
 sqrt = self.sqrt_to_dense()
-return math_ops.batch_matmul(sqrt, sqrt, adj_y=True)
+return math_ops.matmul(sqrt, sqrt, adjoint_b=True)

 def _sqrt_to_dense(self):
 v = self._v
@@ -467,6 +464,6 @@ class OperatorPDSqrtVDVTUpdate(operator_pd.OperatorPDBase):
 d_vt = d.matmul(v, transpose_x=True)
 # Batch op won't be efficient for singletons. Currently we don't break
 # to_dense into batch/singleton methods.
-v_d_vt = math_ops.batch_matmul(v, d_vt)
+v_d_vt = math_ops.matmul(v, d_vt)
 m_plus_v_d_vt = m.to_dense() + v_d_vt
 return m_plus_v_d_vt
@@ -87,9 +87,7 @@ class OperatorPDDerivedClassTest(tf.test.TestCase):
 self.assertEqual(mat.shape, sqrt.get_shape())
 # Square roots are not unique, but SS^T should equal mat. In this
 # case however, we should have S = S^T.
-self._compare_results(
-expected=mat,
-actual=tf.batch_matmul(sqrt, sqrt))
+self._compare_results(expected=mat, actual=tf.matmul(sqrt, sqrt))

 def testDeterminants(self):
 with self.test_session():
@@ -111,8 +109,7 @@ class OperatorPDDerivedClassTest(tf.test.TestCase):
 x = self._rng.randn(*(batch_shape + (k, 5)))

 self._compare_results(
-expected=tf.batch_matmul(mat, x).eval(),
-actual=operator.matmul(x))
+expected=tf.matmul(mat, x).eval(), actual=operator.matmul(x))

 def testSqrtMatmul(self):
 # Square roots are not unique, but we should have SS^T x = Ax, and in our
@@ -126,7 +123,7 @@ class OperatorPDDerivedClassTest(tf.test.TestCase):
 x = self._rng.randn(*(batch_shape + (k, 5)))

 self._compare_results(
-expected=tf.batch_matmul(mat, x).eval(),
+expected=tf.matmul(mat, x).eval(),
 actual=operator.sqrt_matmul(operator.sqrt_matmul(x)))

 def testSolve(self):
@@ -245,7 +245,7 @@ class _WishartOperatorPD(distribution.Distribution):

 if not self.cholesky_input_output_matrices:
 # Complexity: O(nbk^3)
-x = math_ops.batch_matmul(x, x, adj_y=True)
+x = math_ops.matmul(x, x, adjoint_b=True)

 return x

@@ -353,7 +353,7 @@ class _WishartOperatorPD(distribution.Distribution):
 def _variance(self):
 x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
 d = array_ops.expand_dims(array_ops.matrix_diag_part(x), -1)
-v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
+v = math_ops.square(x) + math_ops.matmul(d, d, adjoint_b=True)
 if self.cholesky_input_output_matrices:
 return linalg_ops.cholesky(v)
 return v