Use the scalar cache in MeanGrad.
PiperOrigin-RevId: 168462267
This commit is contained in:
parent
1cada9ea2d
commit
bf96fcd13a
@@ -104,7 +104,12 @@ def _MeanGrad(op, grad):
|
||||
factor = _safe_shape_div(
|
||||
math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
|
||||
if context.in_eager_mode():
|
||||
factor = factor._copy(device_name=sum_grad.device) # pylint: disable=protected-access
|
||||
# Note that we go through numpy here just so we use the eager per-device
|
||||
# scalar cache. We know the factor is a host memory tensor because it's a
|
||||
# shape, and we also know that converting a scalar into a tensor triggers a
|
||||
# per-device cache.
|
||||
factor = factor.numpy()
|
||||
factor = constant_op.constant(factor, dtype=sum_grad.dtype)
|
||||
return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user