Optimize gradients for BatchMatMulV2.

When neither argument has a batch dimension, we do not need to generate the reduction epilog, since we only ever broadcast over batch dimensions. We also do not need to generate the epilog if the batch dimensions are known statically to match.

PiperOrigin-RevId: 339565202
Change-Id: I4cff4220167c96247a3f7f866d7dba9c9c89516f
A. Unique TensorFlower 2020-10-28 16:51:11 -07:00 committed by TensorFlower Gardener
parent df5e170854
commit d9b452c778
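
For context, BatchMatMulV2 broadcasts operands only over their leading batch dimensions; whenever an operand's batch shape is broadcast, its gradient must be summed back down to that operand's shape, and that summation is the reduction epilog this change elides. A minimal sketch of the effect (hypothetical shapes, assuming TensorFlow 2.x eager execution):

import tensorflow as tf

# x's batch dimension (1) is broadcast against y's (5).
x = tf.Variable(tf.ones([1, 2, 3]))
y = tf.Variable(tf.ones([5, 3, 4]))

with tf.GradientTape() as tape:
  z = tf.matmul(x, y)  # dispatches to BatchMatMulV2; z has shape [5, 2, 4]

grad_x, grad_y = tape.gradient(z, [x, y])
print(grad_x.shape)  # (1, 2, 3) -- summed over the broadcast batch dimension
print(grad_y.shape)  # (5, 3, 4) -- no reduction needed

When both operands are plain matrices, or their batch shapes are statically known to be equal, no such summation can ever be needed, which is the fast path this commit adds.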

@@ -1867,18 +1867,25 @@ def _BatchMatMulV2(op, grad):
       grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
       grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
 
-  # Reduce along the broadcasted batch dimensions, if broadcasting is required.
+  # Possibly reduce along the broadcasted batch dimensions, if broadcasting
+  # is required.
   shape_x_static = x.get_shape()
   shape_y_static = y.get_shape()
-  if not (shape_x_static.is_fully_defined() and
-          shape_y_static.is_fully_defined() and
-          shape_x_static == shape_y_static):
-    sx = array_ops.shape(x)
-    sy = array_ops.shape(y)
-    rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
-    grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
-    grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
+  output_may_have_non_empty_batch_shape = (
+      (shape_x_static.rank is None or shape_x_static.rank > 2) or
+      (shape_y_static.rank is None or shape_y_static.rank > 2))
+  batch_shapes_match = (
+      shape_x_static[:-2].is_fully_defined() and
+      shape_y_static[:-2].is_fully_defined() and
+      shape_x_static[:-2] == shape_y_static[:-2])
+  if (not output_may_have_non_empty_batch_shape) or batch_shapes_match:
+    return grad_x, grad_y
+  sx = array_ops.shape(x)
+  sy = array_ops.shape(y)
+  rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
+  grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
+  grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
 
   return grad_x, grad_y
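
The new fast-path predicate depends only on static shapes, so it can be exercised in isolation. The helper below is an illustrative standalone restatement of the check added above (the function name is hypothetical, not part of TensorFlow):

import tensorflow as tf

def needs_reduction_epilog(shape_x_static, shape_y_static):
  # Mirrors the check above: the epilog is needed only when some operand
  # may have batch dimensions and the batch shapes are not statically
  # known to match.
  output_may_have_non_empty_batch_shape = (
      (shape_x_static.rank is None or shape_x_static.rank > 2) or
      (shape_y_static.rank is None or shape_y_static.rank > 2))
  batch_shapes_match = (
      shape_x_static[:-2].is_fully_defined() and
      shape_y_static[:-2].is_fully_defined() and
      shape_x_static[:-2] == shape_y_static[:-2])
  return output_may_have_non_empty_batch_shape and not batch_shapes_match

# Neither argument has a batch dimension: no epilog.
print(needs_reduction_epilog(tf.TensorShape([2, 3]), tf.TensorShape([3, 4])))  # False
# Batch dimensions statically match: no epilog.
print(needs_reduction_epilog(tf.TensorShape([5, 2, 3]), tf.TensorShape([5, 3, 4])))  # False
# Unknown batch dimension: epilog still generated.
print(needs_reduction_epilog(tf.TensorShape([None, 2, 3]), tf.TensorShape([5, 3, 4])))  # True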