Merge pull request #30019 from candyzone:fix_restore_placement

PiperOrigin-RevId: 262939740
This commit is contained in:
TensorFlower Gardener 2019-08-12 09:28:07 -07:00
commit 5daa70bfcf

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import six
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
@ -116,9 +117,10 @@ def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
name: Name of the operation.
"""
base_type = variable.dtype.base_dtype
restore_op = io_ops.restore_v2(
file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
variable._initializer_op = state_ops.assign(variable, restore_op)
with ops.device(variable.device), ops.device("/cpu:0"):
restore_op = io_ops.restore_v2(
file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
variable._initializer_op = state_ops.assign(variable, restore_op)
def _set_variable_or_list_initializer(variable_or_list, file_pattern,