Add additional constraint to place identity under CPU

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-05-22 19:12:20 +00:00
parent 4aecaf3771
commit ee92f24165

View File

@ -641,7 +641,8 @@ def _GatherV2Grad(op, grad):
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
if axis_static == 0:
if context.executing_eagerly():
params_tail_shape = array_ops.identity(params_shape)[1:]
with ops.device("/cpu:0"):
params_tail_shape = array_ops.identity(params_shape)[1:]
else:
params_tail_shape = params_shape[1:]
values_shape = array_ops.concat([indices_size, params_tail_shape], 0)