avoid hard code device for custom device.
PiperOrigin-RevId: 337931634 Change-Id: Ib3dd6f0e342dc440deed7b8b0f83cb54b35410b3
This commit is contained in:
parent
710f3c83b4
commit
646dbdc2f5
@ -665,7 +665,7 @@ def _GatherV2Grad(op, grad):
|
||||
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
|
||||
if axis_static == 0:
|
||||
if context.executing_eagerly():
|
||||
with ops.device("/cpu:0"):
|
||||
with ops.device(indices_size.device):
|
||||
params_tail_shape = array_ops.identity(params_shape)[1:]
|
||||
else:
|
||||
params_tail_shape = params_shape[1:]
|
||||
|
Loading…
Reference in New Issue
Block a user