diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index b8c59bcbb6c..07edd54b494 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -665,7 +665,7 @@ def _GatherV2Grad(op, grad): # For axis 0 gathers, build an appropriately shaped IndexedSlices. if axis_static == 0: if context.executing_eagerly(): - with ops.device("/cpu:0"): + with ops.device(indices_size.device): params_tail_shape = array_ops.identity(params_shape)[1:] else: params_tail_shape = params_shape[1:]