In all cases, we can't rely on the tensor.device attribute being set, so it's better to get the device for a SaveSpec from the device passed in rather than from the tensor. This was an issue with saving iterators: for iterators, the resource usually has a device specification, but the serialized tensor derived from it might not have it set. As a result, when saving iterators in a sharded fashion, all iterators end up on the '' device, which is not what is intended.
Also add support for saving iterators in a sharded fashion, to avoid unnecessary copying during checkpointing. PiperOrigin-RevId: 286310419 Change-Id: I1a957af783f7f69753992ce220b59eb43df2c02f
This commit is contained in:
parent
59ee43d578
commit
c19c8167c2
tensorflow/python
data
training
@ -173,8 +173,8 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
|
||||
self._checkpoint_saver_hook._scaffold is None):
|
||||
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
|
||||
saveables = [iterator_ops._IteratorSaveable(i, i.name) for i in iterators]
|
||||
self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
|
||||
self._latest_filename)
|
||||
self._checkpoint_saver_hook._saver = _CustomSaver(
|
||||
saveables, self._latest_filename, sharded=True)
|
||||
# pylint: enable=protected-access
|
||||
self._checkpoint_saver_hook.begin()
|
||||
|
||||
@ -238,8 +238,8 @@ class _CustomSaver(saver_lib.Saver):
|
||||
the model ckpt saved by the `CheckpointSaverHook`.
|
||||
"""
|
||||
|
||||
def __init__(self, var_list, latest_filename):
|
||||
super(_CustomSaver, self).__init__(var_list)
|
||||
def __init__(self, var_list, latest_filename, sharded=False):
|
||||
super(_CustomSaver, self).__init__(var_list, sharded=sharded)
|
||||
self._latest_filename = latest_filename
|
||||
|
||||
def save(self,
|
||||
|
@ -796,7 +796,11 @@ class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
|
||||
def __init__(self, iterator_resource, name):
|
||||
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
|
||||
specs = [
|
||||
BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
|
||||
BaseSaverBuilder.SaveSpec(
|
||||
serialized_iterator,
|
||||
"",
|
||||
name + "_STATE",
|
||||
device=iterator_resource.device)
|
||||
]
|
||||
super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)
|
||||
|
||||
|
@ -401,7 +401,7 @@ class BaseSaverBuilder(object):
|
||||
per_device = collections.defaultdict(lambda: [])
|
||||
for saveable in saveables:
|
||||
canonical_device = set(
|
||||
pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
|
||||
pydev.canonical_name(spec.device) for spec in saveable.specs)
|
||||
if len(canonical_device) != 1:
|
||||
raise ValueError("All tensors of a saveable object must be "
|
||||
"on the same device: %s" % saveable.name)
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
@ -36,6 +37,7 @@ from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -1108,6 +1110,136 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
class SaveRestoreShardedTestV2(SaveRestoreShardedTest):
|
||||
_WRITE_VERSION = saver_pb2.SaverDef.V2
|
||||
|
||||
def testIterators(self):
|
||||
save_path = os.path.join(self.get_temp_dir(), "sharded_iterators")
|
||||
|
||||
# Build a graph with 2 parameter nodes on different devices and save.
|
||||
with session.Session(
|
||||
target="",
|
||||
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
|
||||
with sess.graph.device("/cpu:0"):
|
||||
ds0 = dataset_ops.Dataset.range(10)
|
||||
it0 = dataset_ops.make_initializable_iterator(ds0)
|
||||
get_next0 = it0.get_next()
|
||||
saveable0 = iterator_ops._IteratorSaveable(
|
||||
it0._iterator_resource, name="saveable_it0")
|
||||
|
||||
with sess.graph.device("/cpu:1"):
|
||||
ds1 = dataset_ops.Dataset.range(20)
|
||||
it1 = dataset_ops.make_initializable_iterator(ds1)
|
||||
get_next1 = it1.get_next()
|
||||
saveable1 = iterator_ops._IteratorSaveable(
|
||||
it1._iterator_resource, name="saveable_it1")
|
||||
saver = saver_module.Saver({
|
||||
"it0": saveable0,
|
||||
"it1": saveable1
|
||||
},
|
||||
write_version=self._WRITE_VERSION,
|
||||
sharded=True)
|
||||
self.evaluate(it0.initializer)
|
||||
self.evaluate(it1.initializer)
|
||||
self.assertEqual(0, self.evaluate(get_next0))
|
||||
self.assertEqual(1, self.evaluate(get_next0))
|
||||
self.assertEqual(0, self.evaluate(get_next1))
|
||||
val = saver.save(sess, save_path)
|
||||
self.assertEqual(save_path, val)
|
||||
data_files = glob.glob(save_path + ".data*")
|
||||
self.assertEqual(2, len(data_files))
|
||||
|
||||
# Restore
|
||||
with session.Session(
|
||||
target="",
|
||||
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
|
||||
with sess.graph.device("/cpu:0"):
|
||||
ds0 = dataset_ops.Dataset.range(10)
|
||||
it0 = dataset_ops.make_initializable_iterator(ds0)
|
||||
get_next0 = it0.get_next()
|
||||
saveable0 = iterator_ops._IteratorSaveable(
|
||||
it0._iterator_resource, name="saveable_it0")
|
||||
|
||||
with sess.graph.device("/cpu:1"):
|
||||
ds1 = dataset_ops.Dataset.range(20)
|
||||
it1 = dataset_ops.make_initializable_iterator(ds1)
|
||||
get_next1 = it1.get_next()
|
||||
saveable1 = iterator_ops._IteratorSaveable(
|
||||
it1._iterator_resource, name="saveable_it1")
|
||||
saver = saver_module.Saver({
|
||||
"it0": saveable0,
|
||||
"it1": saveable1
|
||||
},
|
||||
write_version=self._WRITE_VERSION,
|
||||
sharded=True)
|
||||
self.evaluate(it0.initializer)
|
||||
self.evaluate(it1.initializer)
|
||||
saver.restore(sess, save_path)
|
||||
self.assertEqual(2, self.evaluate(get_next0))
|
||||
self.assertEqual(1, self.evaluate(get_next1))
|
||||
|
||||
def testIteratorsUnshardedRestore(self):
|
||||
save_path = os.path.join(self.get_temp_dir(), "restore_unsharded_iterators")
|
||||
|
||||
# Build a graph with 2 parameter nodes on different devices and save.
|
||||
with session.Session(
|
||||
target="",
|
||||
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
|
||||
with sess.graph.device("/cpu:0"):
|
||||
ds0 = dataset_ops.Dataset.range(10)
|
||||
it0 = dataset_ops.make_initializable_iterator(ds0)
|
||||
get_next0 = it0.get_next()
|
||||
saveable0 = iterator_ops._IteratorSaveable(
|
||||
it0._iterator_resource, name="saveable_it0")
|
||||
|
||||
with sess.graph.device("/cpu:1"):
|
||||
ds1 = dataset_ops.Dataset.range(20)
|
||||
it1 = dataset_ops.make_initializable_iterator(ds1)
|
||||
get_next1 = it1.get_next()
|
||||
saveable1 = iterator_ops._IteratorSaveable(
|
||||
it1._iterator_resource, name="saveable_it1")
|
||||
saver = saver_module.Saver({
|
||||
"it0": saveable0,
|
||||
"it1": saveable1
|
||||
},
|
||||
write_version=self._WRITE_VERSION,
|
||||
sharded=True)
|
||||
self.evaluate(it0.initializer)
|
||||
self.evaluate(it1.initializer)
|
||||
self.assertEqual(0, self.evaluate(get_next0))
|
||||
self.assertEqual(1, self.evaluate(get_next0))
|
||||
self.assertEqual(0, self.evaluate(get_next1))
|
||||
val = saver.save(sess, save_path)
|
||||
self.assertEqual(save_path, val)
|
||||
data_files = glob.glob(save_path + ".data*")
|
||||
self.assertEqual(2, len(data_files))
|
||||
|
||||
# Restore
|
||||
with session.Session(
|
||||
target="",
|
||||
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
|
||||
with sess.graph.device("/cpu:0"):
|
||||
ds0 = dataset_ops.Dataset.range(10)
|
||||
it0 = dataset_ops.make_initializable_iterator(ds0)
|
||||
get_next0 = it0.get_next()
|
||||
saveable0 = iterator_ops._IteratorSaveable(
|
||||
it0._iterator_resource, name="saveable_it0")
|
||||
|
||||
with sess.graph.device("/cpu:1"):
|
||||
ds1 = dataset_ops.Dataset.range(20)
|
||||
it1 = dataset_ops.make_initializable_iterator(ds1)
|
||||
get_next1 = it1.get_next()
|
||||
saveable1 = iterator_ops._IteratorSaveable(
|
||||
it1._iterator_resource, name="saveable_it1")
|
||||
saver = saver_module.Saver({
|
||||
"it0": saveable0,
|
||||
"it1": saveable1
|
||||
},
|
||||
write_version=self._WRITE_VERSION,
|
||||
sharded=False)
|
||||
self.evaluate(it0.initializer)
|
||||
self.evaluate(it1.initializer)
|
||||
saver.restore(sess, save_path)
|
||||
self.assertEqual(2, self.evaluate(get_next0))
|
||||
self.assertEqual(1, self.evaluate(get_next1))
|
||||
|
||||
|
||||
class MaxToKeepTest(test.TestCase):
|
||||
|
||||
|
@ -45,7 +45,10 @@ class SaveSpec(object):
|
||||
self.device = device
|
||||
else:
|
||||
self.dtype = tensor.dtype
|
||||
self.device = tensor.device
|
||||
if device is not None:
|
||||
self.device = device
|
||||
else:
|
||||
self.device = tensor.device
|
||||
|
||||
@property
|
||||
def tensor(self):
|
||||
|
Loading…
Reference in New Issue
Block a user