Merge pull request #30019 from candyzone:fix_restore_placement
PiperOrigin-RevId: 262939740
This commit is contained in:
commit
5daa70bfcf
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
@ -116,9 +117,10 @@ def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
|
||||
name: Name of the operation.
|
||||
"""
|
||||
base_type = variable.dtype.base_dtype
|
||||
restore_op = io_ops.restore_v2(
|
||||
file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
||||
variable._initializer_op = state_ops.assign(variable, restore_op)
|
||||
with ops.device(variable.device), ops.device("/cpu:0"):
|
||||
restore_op = io_ops.restore_v2(
|
||||
file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
|
||||
variable._initializer_op = state_ops.assign(variable, restore_op)
|
||||
|
||||
|
||||
def _set_variable_or_list_initializer(variable_or_list, file_pattern,
|
||||
|
Loading…
x
Reference in New Issue
Block a user