# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

Useful reference for derivative formulas is (Mike Giles, 2008).

Ionescu et al. (2015) provide a detailed derivation of formulas for
backpropagating through spectral layers (SVD and Eig).

References:
  An extended collection of matrix derivative results for
  forward and reverse mode automatic differentiation:
    [Mike Giles, 2008]
    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
  Matrix Backpropagation for Deep Networks with Structured Layers
    [Ionescu et al., 2015]
    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
  Training Deep Networks with Structured Layers by Matrix Backpropagation:
    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)


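# A hedged numerical sketch of the identity implemented above,
# d(A^-1) = -A^-1 dA A^-1 (assuming TF2 eager execution; not executed here):
#
#   a = tf.constant([[2., 1.], [1., 3.]])
#   ainv = tf.linalg.inv(a)
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.inv(a))
#   grad_a = tape.gradient(loss, a)
#   # grad_a is close to
#   #   -tf.matmul(ainv, tf.matmul(tf.ones_like(a), ainv, transpose_b=True),
#   #              transpose_a=True)
#   # (adjoint and transpose coincide for real matrices).

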
@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or if
    the ellipsis is not present), and the negative index if it occurs after the
    ellipsis. E.g. index of `b` in `ab...cd`, is `1`, but that of `c` is `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None

  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. For `ab...cd` returns `[1, -2]`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.

    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and
    # 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced
    # subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. Given the equation "aabbcd->ca" we'd
    # first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = (
        array_ops.concat([
            array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
            array_ops.shape(output_grad)
        ],
                         axis=0))
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz. "bdca->aabbcd")
    # since the output axis labels now appear in the input subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done separately
    on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand i.e. which is not the
        input operand.
      input_shape: A `Tensor` representing the shape of input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    #   where the equation involves only Tensor contractions, generalized traces
    #   and transposes, the input gradients are given by the vector-jacobian
    #   products (VJPs):
    #
    #     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    #   where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    #   x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    #   traces ("ii->i"), the linear mapping's Jacobian at input x is given
    #   by the function itself. We can verify that the linear maps given by the
    #   VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    #   where the latter represents 'un-tracing', or filling the diagonal with
    #   the input entries while the non-diagonal entries are zeros.
    #   Furthermore, recall that matrix multiplication, which is
    #   represented by the equation "ab,bc->ac", has its VJPs given by the
    #   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example
    #   https://math.stackexchange.com/a/2755680). Combined with transposes and
    #   traces we can rewrite Tensor contractions as regular matrix
    #   multiplication. Since each of these operations has its VJP described
    #   by einsums of the required pattern, the result follows.
    #
    # Accordingly, einsum operations except for those with reductions, e.g.
    # "abc,cd->ad" have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain VJP for "ac,cd->ad" and then we compute the VJP of
    # "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscript. E.g. the set {'b'} for
    # the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

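  # Illustrative note (a hedged sketch, not executed here): for the matmul-like
  # equation "ij,jk->ik" nothing is reduced away, so _GetGradWrt above yields
  # the familiar matrix-multiplication gradients:
  #   grad_x = einsum("ik,jk->ij", grad, y)   # i.e. grad @ transpose(y)
  #   grad_y = einsum("ik,ij->jk", grad, x)   # i.e. transpose(x) @ grad
  # For "abc,cd->ad" the label 'b' is reduced, so _GetGradWrt first computes
  # einsum("ad,cd->ac", grad, y) and then lets _GetGradReduced broadcast the
  # result back along the reduced 'b' axis to the "abc" shape.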
  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt the
    # input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels; i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

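  # For instance (sketch only): the transpose equation "ij->ji" takes the first
  # branch above and its gradient is simply einsum("ji->ij", grad), while the
  # reduction "abc->ac" has reduced_label_set == {'b'} and is handled by
  # _GetGradReduced, which broadcasts grad back along the 'b' axis.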
  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If no ellipsis in the output; then no need to unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts
  # 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
  # axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y


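# A hedged usage sketch for the Einsum gradient above (assuming TF2 eager
# execution; not executed here):
#
#   x = tf.random.normal([4, 2, 3])   # batched via "..."
#   y = tf.random.normal([3, 5])      # no batch dimensions
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = tf.einsum("...ij,jk->...ik", x, y)
#     loss = tf.reduce_sum(z)
#   dx, dy = tape.gradient(loss, [x, y])
#   # dx has shape [4, 2, 3]; dy has shape [3, 5] because the gradient is
#   # summed over the broadcasted batch axes, as done at the end of
#   # _EinsumGrad.

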
@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv


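# The formula above is Jacobi's identity d(det A) = det(A) * tr(A^{-1} dA),
# which gives grad_A = grad_det * det(A) * A^{-H}. A minimal numerical sketch
# (assuming TF2 eager execution; not executed here):
#
#   a = tf.constant([[1., 2.], [3., 4.]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     det = tf.linalg.det(a)
#   # tape.gradient(det, a) is close to det * tf.linalg.inv(a), transposed.

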
@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

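  # Clarifying note (informal sketch): with the column-stacking operator vec,
  # vec(R dR) = (I kron R) vec(dR) and vec(dR R) = (R^T kron I) vec(dR), so the
  # forward map dA -> dR applies the inverse of S = I kron R + R^T kron I to
  # vec(dA). Backpropagation applies the transpose of that inverse, and
  # S^T = I kron R^T + R kron I is exactly the `ksum` assembled from `k1` and
  # `k2` below.
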
  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5


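# A hedged numerical sketch for the Cholesky gradient above (assuming TF2
# eager execution; not executed here):
#
#   a = tf.constant([[4., 2.], [2., 3.]])  # symmetric positive definite
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     l = tf.linalg.cholesky(a)
#     loss = tf.reduce_sum(l)
#   grad_a = tape.gradient(loss, a)
#   # grad_a is symmetric because of the final symmetrization step above, and
#   # it matches a central-difference estimate of the derivative when the
#   # perturbation of `a` is kept symmetric.

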
@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrix orders num_rows >= num_cols
    and full_matrices is false.
    """
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    return grad_a + grad_b

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1].value

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case
  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)


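# A hedged usage sketch for the QR gradient above (assuming TF2 eager
# execution; not executed here):
#
#   a = tf.random.normal([5, 3])       # tall matrix, num_rows >= num_cols
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     q, r = tf.linalg.qr(a, full_matrices=False)
#     loss = tf.reduce_sum(r)
#   grad_a = tape.gradient(loss, a)    # handled by _QrGradSquareAndDeepMatrices
#   # Wide matrices (num_rows < num_cols) take the partitioned path above.

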
@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)


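# The formulas above implement the standard identities for X = A^{-1} B
# (adjoint=False): grad_B = A^{-H} grad and grad_A = -grad_B X^H. A minimal
# sketch (assuming TF2 eager execution; not executed here):
#
#   a = tf.constant([[3., 1.], [1., 2.]])
#   b = tf.constant([[1.], [0.]])
#   with tf.GradientTape() as tape:
#     tape.watch([a, b])
#     x = tf.linalg.solve(a, b)
#     loss = tf.reduce_sum(x)
#   grad_a, grad_b = tape.gradient(loss, [a, b])

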
@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  # a) Output the Cholesky factorization from forward op instead of
  #    recomputing it here.
  # b) Implement a symmetric rank-k update op instead of computing
  #    x*z + transpose(x*z). This pattern occurs other places in TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solve the least squares problem
       min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solve the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid nan in cases with degenerate eigenvalues or
# degenerate/zero singular values in calculations of
# f and s_inv_mat, we introduce a Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)


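# Example: _SafeReciprocal(0.) is 0 rather than inf, while for |x| much larger
# than sqrt(epsilon) the result x / (x**2 + epsilon) is numerically
# indistinguishable from 1 / x. This keeps the difference terms in `f` and
# `s_inv_mat` below finite when eigenvalues or singular values coincide.

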
@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from paper by
  Christoph Boeddeker et al.
  https://arxiv.org/abs/1701.00392
  See also
  "Computation of eigenvalue and eigenvector derivatives
  for a general complex-valued eigensystem" by Nico van der Aa.
  As for now only the distinct eigenvalue case is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to a non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


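# A hedged sketch for the compute_uv=False branch above (assuming TF2 eager
# execution; not executed here): for loss = sum of singular values, the
# incoming grad_s is all ones, so the returned gradient is u @ v^H.
#
#   a = tf.random.normal([4, 3])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     s = tf.linalg.svd(a, compute_uv=False)
#     loss = tf.reduce_sum(s)
#   grad_a = tape.gradient(loss, a)
#   s_full, u, v = tf.linalg.svd(a)
#   # grad_a is close to tf.matmul(u, v, adjoint_b=True).

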
def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)


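# For example (on a single 2 x 2 matrix): _LeftShift([[1, 2], [3, 4]]) yields
# [[3, 4], [0, 0]] and _RightShift([[1, 2], [3, 4]]) yields [[0, 0], [1, 2]];
# the shift is along the next-to-last (row) dimension.

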
@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = _RightShift(superdiag_conj * grad) + \
      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")

  # Transposing the matrix within tridiagonal_solve kernel by interchanging
  # superdiagonal and subdiagonal wouldn't work on GPU due to mismatch with
  # paddings required by cusparse*gtsv routines.
  # So constructing the transposed matrix in Python.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(diags_transposed, grad,
                                          partial_pivoting=partial_pivoting)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)


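# A concrete illustration (not executed here): in the compact form, the 3 x 3
# matrix [[d1, u1, 0], [l2, d2, u2], [0, l3, d3]] is stored as
#   [[u1, u2, 0],    # superdiagonal (last entry unused)
#    [d1, d2, d3],   # main diagonal
#    [0,  l2, l3]]   # subdiagonal (first entry unused)
# and _TransposeTridiagonalMatrix returns
#   [[l2, l3, 0],
#    [d1, d2, d3],
#    [0,  u1, u2]],
# i.e. the compact form of the transposed matrix.

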
def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)

  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)