Merge pull request #30019 from candyzone:fix_restore_placement
PiperOrigin-RevId: 262939740
This commit is contained in:
commit
5daa70bfcf
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import io_ops
|
from tensorflow.python.ops import io_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variable_scope as vs
|
from tensorflow.python.ops import variable_scope as vs
|
||||||
@ -116,9 +117,10 @@ def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
|
|||||||
name: Name of the operation.
|
name: Name of the operation.
|
||||||
"""
|
"""
|
||||||
base_type = variable.dtype.base_dtype
|
base_type = variable.dtype.base_dtype
|
||||||
restore_op = io_ops.restore_v2(
|
with ops.device(variable.device), ops.device("/cpu:0"):
|
||||||
file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
restore_op = io_ops.restore_v2(
|
||||||
variable._initializer_op = state_ops.assign(variable, restore_op)
|
file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
||||||
|
variable._initializer_op = state_ops.assign(variable, restore_op)
|
||||||
|
|
||||||
|
|
||||||
def _set_variable_or_list_initializer(variable_or_list, file_pattern,
|
def _set_variable_or_list_initializer(variable_or_list, file_pattern,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user