From 646dbdc2f5a4af3489028e4d32c441c87fc31841 Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Mon, 19 Oct 2020 14:18:10 -0700 Subject: [PATCH] avoid hard code device for custom device. PiperOrigin-RevId: 337931634 Change-Id: Ib3dd6f0e342dc440deed7b8b0f83cb54b35410b3 --- tensorflow/python/ops/array_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index b8c59bcbb6c..07edd54b494 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -665,7 +665,7 @@ def _GatherV2Grad(op, grad): # For axis 0 gathers, build an appropriately shaped IndexedSlices. if axis_static == 0: if context.executing_eagerly(): - with ops.device("/cpu:0"): + with ops.device(indices_size.device): params_tail_shape = array_ops.identity(params_shape)[1:] else: params_tail_shape = params_shape[1:]