Use the scalar cache in MeanGrad.

PiperOrigin-RevId: 168462267
Author: Alexandre Passos 2017-09-12 15:58:07 -07:00
Committed by: TensorFlower Gardener
Parent: 1cada9ea2d
Commit: bf96fcd13a

@@ -104,7 +104,12 @@ def _MeanGrad(op, grad):
   factor = _safe_shape_div(
       math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
   if context.in_eager_mode():
-    factor = factor._copy(device_name=sum_grad.device)  # pylint: disable=protected-access
+    # Note that we go through numpy here just so we use the eager per-device
+    # scalar cache. We know the factor is a host memory tensor because it's a
+    # shape, and we also know that converting a scalar into a tensor triggers a
+    # per-device cache.
+    factor = factor.numpy()
+    factor = constant_op.constant(factor, dtype=sum_grad.dtype)
   return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
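
For context, a minimal standalone sketch of the pattern this change adopts. It assumes a recent TensorFlow with eager execution enabled and uses the public tf.constant API instead of the internal constant_op and _safe_shape_div helpers; the variable names and shapes are illustrative, not the ones in _MeanGrad.

import tensorflow as tf

# Illustrative stand-ins (not the actual _MeanGrad values): a summed
# gradient plus the input/output shapes of a reduce_mean-style op.
sum_grad = tf.ones([4, 4])
input_shape = tf.constant([4, 4])
output_shape = tf.constant([1, 1])

# Ratio of element counts, computed as an eager tensor. The real code uses
# an internal _safe_shape_div helper; plain integer division stands in here.
factor = tf.reduce_prod(input_shape) // tf.reduce_prod(output_shape)

# Round-trip the scalar through numpy and re-wrap it with tf.constant.
# Per the commit's comment, building a tensor from a Python/numpy scalar goes
# through the eager per-device scalar cache, so repeated gradient calls can
# reuse the cached device tensor instead of copying the shape-derived scalar
# to sum_grad's device on every call.
factor = tf.constant(factor.numpy(), dtype=sum_grad.dtype)

mean_grad = sum_grad / tf.cast(factor, sum_grad.dtype)

The design point is that the factor is derived from shapes and therefore lives in host memory; converting it to a Python/numpy scalar and back to a constant is cheap and lets the per-device scalar cache absorb the cost across gradient invocations, whereas the previous _copy call moved the tensor to the gradient's device every time.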