In all cases, we can't rely on the tensor.device attribute being set, so it's better to derive the device for a SaveSpec from the device passed in explicitly rather than from the tensor. This was an issue when saving iterators: the iterator resource usually has a device specification, but the serialized tensor derived from it might not have one set. As a result, when saving iterators in a sharded fashion, all iterators ended up on the '' device, which is not what was intended.

Also adds support for saving iterators in a sharded fashion, to avoid unnecessary copying during checkpointing.

PiperOrigin-RevId: 286310419
Change-Id: I1a957af783f7f69753992ce220b59eb43df2c02f
This commit is contained in:
Rohan Jain 2019-12-18 19:07:50 -08:00 committed by TensorFlower Gardener
parent 59ee43d578
commit c19c8167c2
5 changed files with 146 additions and 7 deletions
tensorflow/python

View File

@ -173,8 +173,8 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
self._checkpoint_saver_hook._scaffold is None):
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
saveables = [iterator_ops._IteratorSaveable(i, i.name) for i in iterators]
self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
self._latest_filename)
self._checkpoint_saver_hook._saver = _CustomSaver(
saveables, self._latest_filename, sharded=True)
# pylint: enable=protected-access
self._checkpoint_saver_hook.begin()
@ -238,8 +238,8 @@ class _CustomSaver(saver_lib.Saver):
the model ckpt saved by the `CheckpointSaverHook`.
"""
def __init__(self, var_list, latest_filename):
super(_CustomSaver, self).__init__(var_list)
def __init__(self, var_list, latest_filename, sharded=False):
super(_CustomSaver, self).__init__(var_list, sharded=sharded)
self._latest_filename = latest_filename
def save(self,

View File

@ -796,7 +796,11 @@ class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
def __init__(self, iterator_resource, name):
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
specs = [
BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
BaseSaverBuilder.SaveSpec(
serialized_iterator,
"",
name + "_STATE",
device=iterator_resource.device)
]
super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)

View File

@ -401,7 +401,7 @@ class BaseSaverBuilder(object):
per_device = collections.defaultdict(lambda: [])
for saveable in saveables:
canonical_device = set(
pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
pydev.canonical_name(spec.device) for spec in saveable.specs)
if len(canonical_device) != 1:
raise ValueError("All tensors of a saveable object must be "
"on the same device: %s" % saveable.name)

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
import glob
import math
import os
import random
@ -36,6 +37,7 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -1108,6 +1110,136 @@ class SaveRestoreShardedTest(test.TestCase):
class SaveRestoreShardedTestV2(SaveRestoreShardedTest):
_WRITE_VERSION = saver_pb2.SaverDef.V2
def testIterators(self):
save_path = os.path.join(self.get_temp_dir(), "sharded_iterators")
# Build a graph with 2 parameter nodes on different devices and save.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
self.assertEqual(0, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next0))
self.assertEqual(0, self.evaluate(get_next1))
val = saver.save(sess, save_path)
self.assertEqual(save_path, val)
data_files = glob.glob(save_path + ".data*")
self.assertEqual(2, len(data_files))
# Restore
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
saver.restore(sess, save_path)
self.assertEqual(2, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next1))
def testIteratorsUnshardedRestore(self):
save_path = os.path.join(self.get_temp_dir(), "restore_unsharded_iterators")
# Build a graph with 2 parameter nodes on different devices and save.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
self.assertEqual(0, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next0))
self.assertEqual(0, self.evaluate(get_next1))
val = saver.save(sess, save_path)
self.assertEqual(save_path, val)
data_files = glob.glob(save_path + ".data*")
self.assertEqual(2, len(data_files))
# Restore
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=False)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
saver.restore(sess, save_path)
self.assertEqual(2, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next1))
class MaxToKeepTest(test.TestCase):

View File

@ -45,7 +45,10 @@ class SaveSpec(object):
self.device = device
else:
self.dtype = tensor.dtype
self.device = tensor.device
if device is not None:
self.device = device
else:
self.device = tensor.device
@property
def tensor(self):