avoid hard code device for custom device.
PiperOrigin-RevId: 337931634 Change-Id: Ib3dd6f0e342dc440deed7b8b0f83cb54b35410b3
This commit is contained in:
parent
710f3c83b4
commit
646dbdc2f5
@ -665,7 +665,7 @@ def _GatherV2Grad(op, grad):
|
|||||||
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
|
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
|
||||||
if axis_static == 0:
|
if axis_static == 0:
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
with ops.device("/cpu:0"):
|
with ops.device(indices_size.device):
|
||||||
params_tail_shape = array_ops.identity(params_shape)[1:]
|
params_tail_shape = array_ops.identity(params_shape)[1:]
|
||||||
else:
|
else:
|
||||||
params_tail_shape = params_shape[1:]
|
params_tail_shape = params_shape[1:]
|
||||||
|
Loading…
Reference in New Issue
Block a user