Optimize gradients for BatchMatMulV2.
When neither argument has a batch dimension, we do not need to generate the reduction epilog, because broadcasting only ever happens over batch dimensions. We also do not need to generate the epilog when the batch dimensions are statically known to match.
PiperOrigin-RevId: 339565202
Change-Id: I4cff4220167c96247a3f7f866d7dba9c9c89516f
This commit is contained in:
parent
df5e170854
commit
d9b452c778
@ -1867,18 +1867,25 @@ def _BatchMatMulV2(op, grad):
|
||||
grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
|
||||
grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)
|
||||
|
||||
# Reduce along the broadcasted batch dimensions, if broadcasting is required.
|
||||
# Possibly reduce along the broadcasted batch dimensions, if broadcasting
|
||||
# is required.
|
||||
shape_x_static = x.get_shape()
|
||||
shape_y_static = y.get_shape()
|
||||
if not (shape_x_static.is_fully_defined() and
|
||||
shape_y_static.is_fully_defined() and
|
||||
shape_x_static == shape_y_static):
|
||||
output_may_have_non_empty_batch_shape = (
|
||||
(shape_x_static.rank is None or shape_x_static.rank > 2) or
|
||||
(shape_y_static.rank is None or shape_y_static.rank > 2))
|
||||
batch_shapes_match = (
|
||||
shape_x_static[:-2].is_fully_defined() and
|
||||
shape_y_static[:-2].is_fully_defined() and
|
||||
shape_x_static[:-2] == shape_y_static[:-2])
|
||||
if (not output_may_have_non_empty_batch_shape) or batch_shapes_match:
|
||||
return grad_x, grad_y
|
||||
|
||||
sx = array_ops.shape(x)
|
||||
sy = array_ops.shape(y)
|
||||
rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
|
||||
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
|
||||
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
|
||||
|
||||
return grad_x, grad_y
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user