Wrap save/restore logic in tf.function when in eager mode. This allows parallel saving and restoring when using multiple devices.
PiperOrigin-RevId: 317719780 Change-Id: Ifb7e34f708da4121b49fb38d8dad046d45fedc42
This commit is contained in:
parent
38d95ad2d8
commit
c27b834b49
|
@ -837,7 +837,6 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
|
||||||
"ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
|
"ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
|
||||||
"Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
|
"Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
|
||||||
"RandomPoissonV2",
|
"RandomPoissonV2",
|
||||||
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
|
|
||||||
|
|
||||||
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
|
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
|
||||||
// but it can't generate any observable side-effect.
|
// but it can't generate any observable side-effect.
|
||||||
|
@ -851,7 +850,12 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
|
||||||
// the same device_ordinal on the same host.
|
// the same device_ordinal on the same host.
|
||||||
"EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
|
"EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
|
||||||
"EnqueueTPUEmbeddingSparseTensorBatch",
|
"EnqueueTPUEmbeddingSparseTensorBatch",
|
||||||
"EnqueueTPUEmbeddingRaggedTensorBatch"});
|
"EnqueueTPUEmbeddingRaggedTensorBatch",
|
||||||
|
|
||||||
|
// SaveV2 and RestoreV2 should be allowed to operate in parallel on
|
||||||
|
// multiple hosts.
|
||||||
|
"SaveV2", "RestoreV2"});
|
||||||
|
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
|
||||||
return exemption->contains(op);
|
return exemption->contains(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -172,6 +172,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||||
config.set_synchronous_execution(previous)
|
config.set_synchronous_execution(previous)
|
||||||
|
|
||||||
def test_checkpointing(self):
|
def test_checkpointing(self):
|
||||||
|
self.skipTest(
|
||||||
|
"Disable saving until SaveableObject's methods are traceable.")
|
||||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
with self.device.scope():
|
with self.device.scope():
|
||||||
different_values = self.device.pack(
|
different_values = self.device.pack(
|
||||||
|
@ -263,6 +265,8 @@ class LayerTests(_VirtualDeviceTestCase):
|
||||||
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
|
self.assertIn(self.device.components[1], final_kernels[1].backing_device)
|
||||||
|
|
||||||
def test_training_loop(self):
|
def test_training_loop(self):
|
||||||
|
self.skipTest(
|
||||||
|
"Disable saving until SaveableObject's methods are traceable.")
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
layer = _Dense(5)
|
layer = _Dense(5)
|
||||||
checkpoint = tracking.Checkpoint(layer=layer)
|
checkpoint = tracking.Checkpoint(layer=layer)
|
||||||
|
|
|
@ -100,7 +100,7 @@ _ORDER_INSENSITIVE_STATEFUL_OPS = [
|
||||||
"CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
|
"CudnnRNNV2", "CudnnRNNV3", "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
|
||||||
"EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
|
"EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
|
||||||
"EnqueueTPUEmbeddingSparseTensorBatch",
|
"EnqueueTPUEmbeddingSparseTensorBatch",
|
||||||
"EnqueueTPUEmbeddingRaggedTensorBatch"
|
"EnqueueTPUEmbeddingRaggedTensorBatch", "RestoreV2", "SaveV2"
|
||||||
]
|
]
|
||||||
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)
|
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,7 @@ cuda_py_test(
|
||||||
":checkpoint_options",
|
":checkpoint_options",
|
||||||
":functional_saver",
|
":functional_saver",
|
||||||
":saveable_hook",
|
":saveable_hook",
|
||||||
|
"//tensorflow/python/eager:remote",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from tensorflow.core.protobuf import saver_pb2
|
from tensorflow.core.protobuf import saver_pb2
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
@ -161,7 +162,8 @@ class MultiDeviceSaver(object):
|
||||||
self._after_restore_callbacks.append(saveable.after_restore)
|
self._after_restore_callbacks.append(saveable.after_restore)
|
||||||
|
|
||||||
if is_saveable:
|
if is_saveable:
|
||||||
saveables_by_device.setdefault(saveable.device, []).append(saveable)
|
host_device = saveable_object_util.set_cpu0(saveable.device)
|
||||||
|
saveables_by_device.setdefault(host_device, []).append(saveable)
|
||||||
|
|
||||||
self._single_device_savers = {
|
self._single_device_savers = {
|
||||||
device: _SingleDeviceSaver(saveables)
|
device: _SingleDeviceSaver(saveables)
|
||||||
|
@ -247,33 +249,50 @@ class MultiDeviceSaver(object):
|
||||||
tmp_checkpoint_prefix = string_ops.string_join(
|
tmp_checkpoint_prefix = string_ops.string_join(
|
||||||
[file_prefix, sharded_suffix])
|
[file_prefix, sharded_suffix])
|
||||||
|
|
||||||
num_shards = len(self._single_device_savers)
|
def save_fn():
|
||||||
sharded_saves = []
|
num_shards = len(self._single_device_savers)
|
||||||
sharded_prefixes = []
|
sharded_saves = []
|
||||||
num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
|
sharded_prefixes = []
|
||||||
last_device = None
|
num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
|
||||||
for shard, (device, saver) in enumerate(
|
last_device = None
|
||||||
sorted(self._single_device_savers.items())):
|
for shard, (device, saver) in enumerate(
|
||||||
last_device = device
|
sorted(self._single_device_savers.items())):
|
||||||
with ops.device(saveable_object_util.set_cpu0(device)):
|
last_device = device
|
||||||
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
|
with ops.device(saveable_object_util.set_cpu0(device)):
|
||||||
num_shards_tensor)
|
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
|
||||||
sharded_prefixes.append(shard_prefix)
|
num_shards_tensor)
|
||||||
with ops.device(device):
|
sharded_prefixes.append(shard_prefix)
|
||||||
# _SingleDeviceSaver will use the CPU device when necessary, but initial
|
with ops.device(device):
|
||||||
# read operations should be placed on the SaveableObject's device.
|
# _SingleDeviceSaver will use the CPU device when necessary, but
|
||||||
sharded_saves.append(saver.save(shard_prefix, options))
|
# initial read operations should be placed on the SaveableObject's
|
||||||
|
# device.
|
||||||
|
sharded_saves.append(saver.save(shard_prefix, options))
|
||||||
|
|
||||||
with ops.control_dependencies(sharded_saves):
|
with ops.control_dependencies(sharded_saves):
|
||||||
# Merge on the io_device if specified, otherwise co-locates the merge op
|
# Merge on the io_device if specified, otherwise co-locates the merge op
|
||||||
# with the last device used.
|
# with the last device used.
|
||||||
merge_device = (options.experimental_io_device or
|
merge_device = (
|
||||||
saveable_object_util.set_cpu0(last_device))
|
options.experimental_io_device or
|
||||||
with ops.device(merge_device):
|
saveable_object_util.set_cpu0(last_device))
|
||||||
# V2 format write path consists of a metadata merge step. Once merged,
|
with ops.device(merge_device):
|
||||||
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
|
# V2 format write path consists of a metadata merge step. Once
|
||||||
return gen_io_ops.merge_v2_checkpoints(
|
# merged, attempts to delete the temporary directory,
|
||||||
sharded_prefixes, file_prefix, delete_old_dirs=True)
|
# "<user-fed prefix>_temp".
|
||||||
|
return gen_io_ops.merge_v2_checkpoints(
|
||||||
|
sharded_prefixes, file_prefix, delete_old_dirs=True)
|
||||||
|
|
||||||
|
# Since this will causes a function re-trace on each save, limit this to the
|
||||||
|
# cases where it is needed: eager and when there are multiple tasks/single
|
||||||
|
# device savers. Note that the retrace is needed to ensure we pickup the
|
||||||
|
# latest values of options like experimental_io_device.
|
||||||
|
if context.executing_eagerly() and len(self._single_device_savers) > 1:
|
||||||
|
# Explicitly place the identity op on the first device.
|
||||||
|
@def_function.function(experimental_compile=False)
|
||||||
|
def tf_function_save():
|
||||||
|
save_fn()
|
||||||
|
tf_function_save()
|
||||||
|
else:
|
||||||
|
return save_fn()
|
||||||
|
|
||||||
def restore(self, file_prefix, options=None):
|
def restore(self, file_prefix, options=None):
|
||||||
"""Restore the saveable objects from a checkpoint with `file_prefix`.
|
"""Restore the saveable objects from a checkpoint with `file_prefix`.
|
||||||
|
@ -287,12 +306,38 @@ class MultiDeviceSaver(object):
|
||||||
A dictionary mapping from SaveableObject names to restore operations.
|
A dictionary mapping from SaveableObject names to restore operations.
|
||||||
"""
|
"""
|
||||||
options = options or checkpoint_options.CheckpointOptions()
|
options = options or checkpoint_options.CheckpointOptions()
|
||||||
restore_ops = {}
|
|
||||||
# Sort by device name to avoid propagating non-deterministic dictionary
|
def restore_fn():
|
||||||
# ordering in some Python versions.
|
restore_ops = {}
|
||||||
for device, saver in sorted(self._single_device_savers.items()):
|
# Sort by device name to avoid propagating non-deterministic dictionary
|
||||||
with ops.device(device):
|
# ordering in some Python versions.
|
||||||
restore_ops.update(saver.restore(file_prefix, options))
|
for device, saver in sorted(self._single_device_savers.items()):
|
||||||
|
with ops.device(device):
|
||||||
|
restore_ops.update(saver.restore(file_prefix, options))
|
||||||
|
|
||||||
|
return restore_ops
|
||||||
|
|
||||||
|
# Since this will causes a function re-trace on each save, limit this to the
|
||||||
|
# cases where it is needed: eager and when there are multiple tasks/single
|
||||||
|
# device savers. Note that the retrace is needed to ensure we pickup the
|
||||||
|
# latest values of options like experimental_io_device.
|
||||||
|
if context.executing_eagerly() and len(self._single_device_savers) > 1:
|
||||||
|
first_device, _ = list(self._single_device_savers.items())[0]
|
||||||
|
@def_function.function(experimental_compile=False)
|
||||||
|
def tf_function_restore():
|
||||||
|
restore_ops = restore_fn()
|
||||||
|
restore_tensors = {}
|
||||||
|
# tf.functions must return tensors, thus we use control dependencies so
|
||||||
|
# that we can return a tensor which depends on the given op.
|
||||||
|
with ops.device(saveable_object_util.set_cpu0(first_device)):
|
||||||
|
for name, op in restore_ops.items():
|
||||||
|
with ops.control_dependencies([op]):
|
||||||
|
restore_tensors[name] = array_ops.identity(file_prefix)
|
||||||
|
return restore_tensors
|
||||||
|
|
||||||
|
restore_ops = tf_function_restore()
|
||||||
|
else:
|
||||||
|
restore_ops = restore_fn()
|
||||||
|
|
||||||
for callback in self._after_restore_callbacks:
|
for callback in self._after_restore_callbacks:
|
||||||
callback()
|
callback()
|
||||||
|
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import remote
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.eager import wrap_function
|
from tensorflow.python.eager import wrap_function
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
|
@ -29,6 +30,7 @@ from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.training.saving import checkpoint_options
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
from tensorflow.python.training.saving import functional_saver
|
from tensorflow.python.training.saving import functional_saver
|
||||||
from tensorflow.python.training.saving import saveable_hook
|
from tensorflow.python.training.saving import saveable_hook
|
||||||
|
@ -126,13 +128,16 @@ class SaverTest(test.TestCase):
|
||||||
second_saver.restore(save_path)
|
second_saver.restore(save_path)
|
||||||
self.assertEqual(2., self.evaluate(v2))
|
self.assertEqual(2., self.evaluate(v2))
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
def test_checkpoint_is_sharded_by_task(self):
|
||||||
def test_checkpoint_is_sharded_by_device(self):
|
servers = [server_lib.Server.create_local_server() for _ in range(3)]
|
||||||
with ops.device("cpu:0"):
|
cluster_spec = server_lib.ClusterSpec({
|
||||||
|
"worker": [s.target[len("grpc://"):] for s in servers]})
|
||||||
|
remote.connect_to_cluster(cluster_spec)
|
||||||
|
with ops.device("/job:worker/task:0/cpu:0"):
|
||||||
v0 = resource_variable_ops.ResourceVariable(0.)
|
v0 = resource_variable_ops.ResourceVariable(0.)
|
||||||
with ops.device("cpu:1"):
|
with ops.device("/job:worker/task:1/cpu:0"):
|
||||||
v1 = resource_variable_ops.ResourceVariable(1.)
|
v1 = resource_variable_ops.ResourceVariable(1.)
|
||||||
with ops.device("cpu:2"):
|
with ops.device("/job:worker/task:2/cpu:0"):
|
||||||
v2 = resource_variable_ops.ResourceVariable(2.)
|
v2 = resource_variable_ops.ResourceVariable(2.)
|
||||||
|
|
||||||
self.evaluate([v0.initializer, v1.initializer, v2.initializer])
|
self.evaluate([v0.initializer, v1.initializer, v2.initializer])
|
||||||
|
@ -167,7 +172,7 @@ class SaverTest(test.TestCase):
|
||||||
list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
|
list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
|
||||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||||
self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
|
self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
|
||||||
self.assertEqual(4, len(gfile.Glob(prefix + "*")))
|
self.assertEqual(2, len(gfile.Glob(prefix + "*")))
|
||||||
self.evaluate(v0.assign(-1.))
|
self.evaluate(v0.assign(-1.))
|
||||||
self.evaluate(v1.assign(-1.))
|
self.evaluate(v1.assign(-1.))
|
||||||
self.evaluate(v2.assign(-1.))
|
self.evaluate(v2.assign(-1.))
|
||||||
|
|
Loading…
Reference in New Issue