avoid hard code device for custom device.

PiperOrigin-RevId: 337931634
Change-Id: Ib3dd6f0e342dc440deed7b8b0f83cb54b35410b3
This commit is contained in:
Jianwei Xie 2020-10-19 14:18:10 -07:00 committed by TensorFlower Gardener
parent 710f3c83b4
commit 646dbdc2f5

View File

@ -665,7 +665,7 @@ def _GatherV2Grad(op, grad):
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
if axis_static == 0:
if context.executing_eagerly():
with ops.device("/cpu:0"):
with ops.device(indices_size.device):
params_tail_shape = array_ops.identity(params_shape)[1:]
else:
params_tail_shape = params_shape[1:]