# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.python.training.saver.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import math
import os
import random
import time
import numpy as np
import six
from google.protobuf.any_pb2 import Any
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.util import compat


class SaverTest(test.TestCase):
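  """Tests basic save/restore behavior of saver_module.Saver."""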
def basicSaveRestore(self, variable_op):
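    """Saves variables built by variable_op, then restores into fresh graphs."""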
save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variable_op(10.0, name="v0")
v1 = variable_op(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
# Initialize all variables
if not context.executing_eagerly():
self.evaluate([variables.global_variables_initializer(), v2_init])
# Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
self.assertEqual(b"k1", self.evaluate(v2.keys()))
self.assertEqual(30.0, self.evaluate(v2.values()))
# Save the initialized values in the file at "save_path"
save = saver_module.Saver(
{
"v0": v0,
"v1": v1,
"v2": v2.saveable
}, restore_sequentially=True)
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
    # Start a second session. In that session the parameter nodes
    # have not been initialized.
with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(-1.0, name="v0")
v1 = variable_op(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Assert that the variables are not initialized.
if not context.executing_eagerly():
self.assertEqual(
len(variables.report_uninitialized_variables().eval()), 2)
self.assertEqual(0, len(self.evaluate(v2.keys())))
self.assertEqual(0, len(self.evaluate(v2.values())))
# Restore the saved values in the parameter nodes.
save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
self.assertEqual(b"k1", self.evaluate(v2.keys()))
self.assertEqual(30.0, self.evaluate(v2.values()))
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
with self.session(graph=ops_lib.Graph()) as sess:
v0_2 = variable_op(1000.0, name="v0")
v1_2 = variable_op(2000.0, name="v1")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2_2.insert("k1000", 3000.0)
# Check that the parameter nodes have been initialized.
if not context.executing_eagerly():
init_all_op = [variables.global_variables_initializer(), v2_init]
self.evaluate(init_all_op)
        # TODO(xpan): Why doesn't _mutable_hash_table_v2 create an empty
        # table in eager mode, as it claims to?
self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
self.assertEqual(3000.0, self.evaluate(v2_2.values()))
self.assertEqual(1000.0, self.evaluate(v0_2))
self.assertEqual(2000.0, self.evaluate(v1_2))
# Restore the values saved earlier in the parameter nodes.
save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
save2.restore(sess, save_path)
# Check that the parameter nodes have been restored.
self.assertEqual(10.0, self.evaluate(v0_2))
self.assertEqual(20.0, self.evaluate(v1_2))
self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
self.assertEqual(30.0, self.evaluate(v2_2.values()))
def testBasic(self):
self.basicSaveRestore(variables.Variable)
@test_util.run_in_graph_and_eager_modes
def testResourceBasic(self):
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
def testResourceColocation(self):
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default():
partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
with ops_lib.device("/job:ps/device:GPU:0"):
v = variable_scope.get_variable(
"v0", shape=[10, 2], partitioner=partitioner, use_resource=True)
saver_module.Saver({"v0": v}).build()
save_op = None
for op in ops_lib.get_default_graph().get_operations():
if op.type == "SaveV2":
save_op = op
break
assert save_op is not None
for save_inp in save_op.inputs[3:]:
        # Inputs to the SaveV2 op are placed on the CPU of the same device
        # as the variable.
self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
def testResourceVariableReadOpsAddedDeterministically(self):
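    # Building identical graphs (variables plus a Saver) repeatedly should
    # yield byte-identical GraphDefs; nondeterministic ordering of the read
    # ops added for resource variables would make them differ.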
graph_defs = []
num_graphs = 10
for _ in range(num_graphs):
with ops_lib.Graph().as_default() as g:
for i in range(20):
resource_variable_ops.ResourceVariable(i, name="var%s" % i)
saver_module.Saver()
graph_defs.append(g.as_graph_def())
for i in range(num_graphs - 1):
self.assertEqual(graph_defs[i], graph_defs[i + 1])
def testEagerBasic(self):
with context.eager_mode():
ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")
v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")
save = saver_module.Saver([v1, v2])
save.save(None, ckpt_prefix)
v1.assign(0.0)
v2.assign([0, 0])
self.assertNear(0.0, self.evaluate(v1), 1e-5)
self.assertAllEqual([0, 0], self.evaluate(v2))
save.restore(None, ckpt_prefix)
self.assertNear(3.14, self.evaluate(v1), 1e-5)
self.assertAllEqual([1, 2], self.evaluate(v2))
def testEagerGraphCompatibility(self):
# Save from graph mode and restore from eager mode.
graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
with context.graph_mode():
with self.session(graph=ops_lib.Graph()) as sess:
# Create a graph model and save the checkpoint.
w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
graph_saver = saver_module.Saver([w1, w2])
self.evaluate(variables.global_variables_initializer())
graph_saver.save(sess, graph_ckpt_prefix)
with context.eager_mode():
ops_lib._default_graph_stack.reset() # pylint: disable=protected-access
ops_lib.reset_default_graph()
w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")
w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")
graph_saver = saver_module.Saver([w1, w2])
graph_saver.restore(None, graph_ckpt_prefix)
self.assertAllEqual(self.evaluate(w1), 1.0)
self.assertAllEqual(self.evaluate(w2), 2.0)
# Save from eager mode and restore from graph mode.
eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")
with context.eager_mode():
ops_lib._default_graph_stack.reset() # pylint: disable=protected-access
ops_lib.reset_default_graph()
w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")
w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")
graph_saver = saver_module.Saver([w3, w4])
graph_saver.save(None, eager_ckpt_prefix)
with context.graph_mode():
with self.session(graph=ops_lib.Graph()) as sess:
w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
graph_saver = saver_module.Saver([w3, w4])
self.evaluate(variables.global_variables_initializer())
graph_saver.restore(sess, eager_ckpt_prefix)
self.assertAllEqual(w3, 3.0)
self.assertAllEqual(w4, 4.0)
@test_util.run_in_graph_and_eager_modes
def testResourceSaveRestoreCachingDevice(self):
save_path = os.path.join(self.get_temp_dir(), "resource_cache")
with self.session(graph=ops_lib.Graph()) as sess:
v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",
name="v")
if context.executing_eagerly():
sess = None
else:
self.evaluate(variables.global_variables_initializer())
save = saver_module.Saver([v])
save.save(sess, save_path)
save2 = saver_module.Saver([v])
save2.restore(sess, save_path)
self.assertEqual(self.evaluate(v), [1])
def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self):
with ops_lib.Graph().as_default() as g:
v = resource_variable_ops.ResourceVariable(1.0, name="v")
with ops_lib.name_scope("saver1"):
saver_module.Saver()
with ops_lib.name_scope("saver2"):
saver_module.Saver({"name": v})
ops_in_saver1_scope_but_not_save_scope = [
op for op in g.get_operations()
if (op.name.startswith("saver1/") and
not op.name.startswith("saver1/save/"))]
self.assertEqual(ops_in_saver1_scope_but_not_save_scope, [])
ops_in_saver2_scope_but_not_save_scope = [
op for op in g.get_operations()
if (op.name.startswith("saver2/") and
not op.name.startswith("saver2/save/"))]
self.assertEqual(ops_in_saver2_scope_but_not_save_scope, [])
def testSaveCopyRestoreWithSaveRelativePaths(self):
"""Save, copy checkpoint dir and restore from copied dir.
This only works for save_relative_paths=True.
"""
save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")
os.mkdir(save_dir1)
save_path1 = os.path.join(save_dir1, "save_copy_restore")
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default():
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variables.VariableV1(10.0, name="v0")
v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver(
var_list={
"v0": v0,
"v1": v1,
"v2": v2.saveable
},
restore_sequentially=True,
save_relative_paths=True)
init_all_op = [variables.global_variables_initializer(), v2_init]
with self.cached_session() as sess:
# Initialize all variables
self.evaluate(init_all_op)
# Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
self.assertEqual(b"k1", self.evaluate(v2.keys()))
self.assertEqual(30.0, self.evaluate(v2.values()))
# Save the initialized values in the file at "save_path"
val = save.save(sess, save_path1)
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path1, val)
self.assertEqual(
checkpoint_management.latest_checkpoint(save_dir1), save_path1)
save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
os.renames(save_dir1, save_dir2)
save_path2 = os.path.join(save_dir2, "save_copy_restore")
self.assertEqual(
checkpoint_management.latest_checkpoint(save_dir2), save_path2)
      # Start a second session. In that session the parameter nodes
      # have not been initialized.
with self.cached_session() as sess:
v0 = variables.VariableV1(-1.0, name="v0")
v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
# Assert that the variables are not initialized.
self.assertEqual(
len(variables.report_uninitialized_variables().eval()), 2)
self.assertEqual(0, len(self.evaluate(v2.keys())))
self.assertEqual(0, len(self.evaluate(v2.values())))
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path2)
# Check that the parameter nodes have been restored.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
self.assertEqual(b"k1", self.evaluate(v2.keys()))
self.assertEqual(30.0, self.evaluate(v2.values()))
def testFilenameTensor(self):
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default():
v0 = variables.VariableV1(0, name="v0")
filename = b"somerandomfilename"
save = saver_module.Saver({"v0": v0}, filename=filename)
with self.cached_session() as sess:
tensor = sess.graph.get_tensor_by_name(
save.saver_def.filename_tensor_name)
self.assertEqual(self.evaluate(tensor), filename)
def testInvalidPath(self):
v0 = variables.VariableV1(0, name="v0")
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.cached_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
with self.assertRaisesRegex(
ValueError, "The passed save_path is not a valid checkpoint:"):
save.restore(sess, "invalid path")
@test_util.run_v1_only("train.Saver is V1 only API.")
def testInt64(self):
save_path = os.path.join(self.get_temp_dir(), "int64")
with self.cached_session() as sess:
      # Build a graph with one variable and Save and Restore nodes for it.
v = variables.VariableV1(np.int64(15), name="v")
save = saver_module.Saver({"v": v}, restore_sequentially=True)
self.evaluate(variables.global_variables_initializer())
# Save the initialized values in the file at "save_path"
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
with self.cached_session() as sess:
v = variables.VariableV1(np.int64(-1), name="v")
save = saver_module.Saver({"v": v})
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v" in e.message):
self.evaluate(v)
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
self.assertEqual(np.int64(15), self.evaluate(v))
def testSomeErrors(self):
with ops_lib.Graph().as_default():
v0 = variables.VariableV1([10.0], name="v0")
v1 = variables.VariableV1([20.0], name="v1")
v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
      # Because of the save slice info above, "v2" is saved under the name
      # "v1", which collides with v1 and raises an error.
with self.assertRaisesRegex(ValueError, "same name: v1"):
saver_module.Saver([v0, v1, v2])
# The names are different and will work.
saver_module.Saver({"vee1": v1, "other": [v2]})
# Partitioned variables also cause name conflicts.
p_v1 = variable_scope.get_variable(
"p_v1",
shape=[4, 5],
partitioner=partitioned_variables.fixed_size_partitioner(
num_shards=2))
p_v2 = variable_scope.get_variable(
"p_v2",
shape=[4, 5],
partitioner=partitioned_variables.fixed_size_partitioner(
num_shards=2))
p_v2._name = "p_v1"
with self.assertRaisesRegex(ValueError, "same name: p_v1"):
saver_module.Saver([p_v1, p_v2])
def testSameName(self):
with ops_lib.Graph().as_default():
v0 = variables.VariableV1([10.0], name="v0")
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saving one variable under two names raises an error.
with self.assertRaisesRegex(
ValueError, "The same saveable will be restored with two names: v0"):
saver_module.Saver({"v0": v0, "v0too": v0})
# Ditto for custom saveables.
with self.assertRaisesRegex(
ValueError, "The same saveable will be restored with two names: v2"):
saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable})
# Verify non-duplicate names work.
saver_module.Saver({"v0": v0, "v2": v2.saveable})
@test_util.run_v1_only("train.Saver and VariableV1 are V1 only APIs.")
def testBasicsWithListOfVariables(self):
save_path = os.path.join(self.get_temp_dir(), "basics_with_list")
with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variables.VariableV1(10.0, name="v0")
v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver([v0, v1, v2.saveable])
self.evaluate(variables.global_variables_initializer())
v2_init.run()
# Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
self.assertEqual(b"k1", self.evaluate(v2.keys()))
self.assertEqual(30.0, self.evaluate(v2.values()))
# Save the initialized values in the file at "save_path"
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
    # Start a second session. In that session the variables
    # have not been initialized.
with self.session(graph=ops_lib.Graph()) as sess:
v0 = variables.VariableV1(-1.0, name="v0")
v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver([v0, v1, v2.saveable])
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):
self.evaluate(v0)
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
self.evaluate(v1)
self.assertEqual(0, len(self.evaluate(v2.keys())))
self.assertEqual(0, len(self.evaluate(v2.values())))
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
self.assertEqual(b"k1", self.evaluate(v2.keys()))
self.assertEqual(30.0, self.evaluate(v2.values()))
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
with self.session(graph=ops_lib.Graph()) as sess:
v0_2 = variables.VariableV1(1000.0, name="v0")
v1_2 = variables.VariableV1(2000.0, name="v1")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
v2_2.insert("k1000", 3000.0).run()
self.evaluate(variables.global_variables_initializer())
# Check that the parameter nodes have been initialized.
self.assertEqual(1000.0, self.evaluate(v0_2))
self.assertEqual(2000.0, self.evaluate(v1_2))
self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
self.assertEqual(3000.0, self.evaluate(v2_2.values()))
# Restore the values saved earlier in the parameter nodes.
save2.restore(sess, save_path)
# Check that the parameter nodes have been restored.
self.assertEqual(10.0, self.evaluate(v0_2))
self.assertEqual(20.0, self.evaluate(v1_2))
self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
self.assertEqual(30.0, self.evaluate(v2_2.values()))
def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
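    """Saves var_value under var_name, then restores it over other_value."""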
with self.session(graph=ops_lib.Graph()) as sess:
var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
save = saver_module.Saver({var_name: var})
if not context.executing_eagerly():
self.evaluate(var.initializer)
val = save.save(sess, save_path)
self.assertEqual(save_path, val)
with self.session(graph=ops_lib.Graph()) as sess:
var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
save = saver_module.Saver({var_name: var})
save.restore(sess, save_path)
self.assertAllClose(var_value, self.evaluate(var))
def testCacheRereadsFile(self):
save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
# Save and reload one Variable named "var0".
self._SaveAndLoad("var0", 0.0, 1.0, save_path)
# Save and reload one Variable named "var1" in the same file.
# The cached readers should know to re-read the file.
self._SaveAndLoad("var1", 1.1, 2.2, save_path)
def testAllowEmpty(self):
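    # With allow_empty=True and no variables to save, save() writes nothing
    # and returns None.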
save_path = os.path.join(self.get_temp_dir(), "allow_empty")
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default(), self.cached_session() as sess:
_ = constant_op.constant(1)
save = saver_module.Saver(allow_empty=True)
val = save.save(sess, save_path)
self.assertIsNone(val)
with ops_lib.Graph().as_default(), self.cached_session() as sess:
save = saver_module.Saver(allow_empty=True)
save.restore(sess, save_path)
def testGPU(self):
if not test.is_gpu_available():
return
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1})
self.evaluate(variables.global_variables_initializer())
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2})
self.evaluate(variables.global_variables_initializer())
def testSharedServerOnGPU(self):
if not test.is_gpu_available():
return
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
self.evaluate(variables.global_variables_initializer())
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
self.evaluate(variables.global_variables_initializer())
def testVariables(self):
save_path = os.path.join(self.get_temp_dir(), "variables")
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.VariableV1(1.0)
twos = variables.VariableV1([2.0, 2.0, 2.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
init = variables.global_variables_initializer()
save = saver_module.Saver()
init.run()
v2.insert("k1", 3.0).run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.VariableV1(0.0)
twos = variables.VariableV1([0.0, 0.0, 0.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
      # A Saver with no arguments defaults to 'all variables'.
save = saver_module.Saver()
save.restore(sess, save_path)
self.assertAllClose(1.0, self.evaluate(one))
self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
self.assertEqual(b"k1", self.evaluate(v2.keys()))
self.assertEqual(3.0, self.evaluate(v2.values()))
def testVarListShouldBeEmptyInDeferredBuild(self):
with ops_lib.Graph().as_default():
v = variables.VariableV1(1.0)
with self.assertRaisesRegex(ValueError, "defer_build"):
saver_module.Saver([v], defer_build=True)
def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
with ops_lib.Graph().as_default(), session.Session() as sess:
variables.VariableV1(1.0)
saver = saver_module.Saver(defer_build=True)
with self.assertRaisesRegex(RuntimeError, "build"):
saver.save(sess, save_path)
def testDeferredBuild(self):
save_path = os.path.join(self.get_temp_dir(), "deferred_build")
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.VariableV1(1.0)
save = saver_module.Saver(defer_build=True)
      # If the build were not deferred, the saver could not save `twos`,
      # which is created after the Saver.
twos = variables.VariableV1([2.0, 2.0, 2.0])
init = variables.global_variables_initializer()
save.build()
init.run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.VariableV1(0.0)
twos = variables.VariableV1([0.0, 0.0, 0.0])
      # A Saver with no arguments defaults to 'all variables'.
save = saver_module.Saver()
save.restore(sess, save_path)
self.assertAllClose(1.0, self.evaluate(one))
self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
@test_util.run_v1_only("train.Saver is V1 only API.")
def testReshape(self):
save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
with session.Session("", graph=ops_lib.Graph()) as sess:
var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
init = variables.global_variables_initializer()
save = saver_module.Saver()
init.run()
save.save(sess, save_path)
# Error when restoring with default reshape=False
with session.Session("", graph=ops_lib.Graph()) as sess:
var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver()
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
"Assign requires shapes of both tensors to match."):
save.restore(sess, save_path)
# Restored to new shape with reshape=True
with session.Session("", graph=ops_lib.Graph()) as sess:
var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver(reshape=True)
save.restore(sess, save_path)
self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
self.evaluate(var))
@test_util.run_in_graph_and_eager_modes
def testSaveWithGlobalStep(self, pad_step_number=False):
save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
global_step_int = 5
# Save and reload one Variable named "var0".
self._SaveAndLoad("var0", 0.0, 1.0, save_path)
for use_tensor in [True, False]:
with self.session(graph=ops_lib.Graph()):
var = resource_variable_ops.ResourceVariable(1.0, name="var0")
save = saver_module.Saver(
{
var._shared_name: var
}, pad_step_number=pad_step_number)
if context.executing_eagerly():
sess = None
else:
self.evaluate(var.initializer)
sess = ops_lib.get_default_session()
if use_tensor:
global_step = constant_op.constant(global_step_int)
val = save.save(sess, save_path, global_step=global_step)
else:
val = save.save(sess, save_path, global_step=global_step_int)
if pad_step_number:
expected_save_path = "%s-%s" % (save_path,
"{:08d}".format(global_step_int))
else:
expected_save_path = "%s-%d" % (save_path, global_step_int)
self.assertEqual(expected_save_path, val)
def testSaveWithGlobalStepWithPadding(self):
self.testSaveWithGlobalStep(pad_step_number=True)
def testSaveToNonexistingPath(self):
file_io.write_string_to_file(
os.path.join(self.get_temp_dir(), "actually_a_file"), "")
paths = [
os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),
os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),
os.path.join(self.get_temp_dir(), "actually_a_file/path"),
]
for save_path in paths:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variables.VariableV1(10.0, name="v0")
v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
      # When the parent directory doesn't exist, whether the save succeeds
      # or fails is implementation-dependent, so we allow both outcomes.
try:
with self.cached_session() as sess:
# Initialize all variables
self.evaluate(init_all_op)
# Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
# Save the graph.
save.save(sess, save_path)
with self.cached_session() as sess:
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
except ValueError as exc:
error_msg_template = "Parent directory of {} doesn't exist, can't save."
self.assertEqual(error_msg_template.format(save_path), str(exc))
def testSaveToURI(self):
# ParseURI functions don't work on Windows yet.
# TODO(jhseu): Remove this check when it works.
if os.name == "nt":
self.skipTest("Local URI support doesn't work on Windows")
save_path = "file://" + os.path.join(self.get_temp_dir(), "uri")
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variables.VariableV1(10.0, name="v0")
v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
with self.cached_session() as sess:
# Initialize all variables
self.evaluate(init_all_op)
# Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
save.save(sess, save_path)
def testSaveRestoreAndValidateVariableDtype(self):
for variable_op in [
variables.Variable, resource_variable_ops.ResourceVariable
]:
save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
# Build the first session.
with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(10.0, name="v0", dtype=dtypes.float32)
if not context.executing_eagerly():
self.evaluate([variables.global_variables_initializer()])
save = saver_module.Saver({"v0": v0})
save.save(sess, save_path)
# Start a second session.
with self.session(graph=ops_lib.Graph()) as sess:
v0_wrong_dtype = variable_op(1, name="v0", dtype=dtypes.int32)
      # Restore the saved value into a parameter node of a different dtype.
save = saver_module.Saver({"v0": v0_wrong_dtype})
with self.assertRaisesRegex(errors.InvalidArgumentError,
"original dtype"):
save.restore(sess, save_path)
# Test restoring large tensors (triggers a thread pool)
def testRestoreLargeTensors(self):
save_dir = self.get_temp_dir()
def _model():
small_v = [variable_scope.get_variable(
"small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)]
large_v = [variable_scope.get_variable(
"large%d" % i, shape=[32000, 1000], use_resource=True)
for i in range(3)]
return small_v + large_v
save_graph = ops_lib.Graph()
with save_graph.as_default(), self.session(graph=save_graph) as sess:
orig_vars = _model()
self.evaluate(variables.global_variables_initializer())
save = saver_module.Saver(max_to_keep=1)
self.evaluate(variables.global_variables_initializer())
save.save(sess, save_dir)
orig_vals = self.evaluate(orig_vars)
restore_graph = ops_lib.Graph()
with restore_graph.as_default(), self.session(
graph=restore_graph) as sess:
restored_vars = _model()
save = saver_module.Saver(max_to_keep=1)
save.restore(sess, save_dir)
restored_vals = self.evaluate(restored_vars)
for orig, restored in zip(orig_vals, restored_vals):
self.assertAllEqual(orig, restored)


class SaveRestoreShardedTest(test.TestCase):
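  """Tests sharded save/restore; _WRITE_VERSION selects the checkpoint format."""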
_WRITE_VERSION = saver_pb2.SaverDef.V1
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def testBasics(self):
save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
# Build a graph with 2 parameter nodes on different devices.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.VariableV1(10, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
v1 = variables.VariableV1(20, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
"v0": v0,
"v1": v1,
"t0": t0.saveable,
"t1": t1.saveable
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(variables.global_variables_initializer())
t0.insert("k1", 30.0).run()
t1.insert("k2", 40.0).run()
val = save.save(sess, save_path)
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(save_path + "-?????-of-00002", val)
else:
self.assertEqual(save_path, val)
meta_graph_filename = checkpoint_management.meta_graph_filename(val)
self.assertEqual(save_path + ".meta", meta_graph_filename)
if save._write_version is saver_pb2.SaverDef.V1:
# Restore different ops from shard 0 of the saved files.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
save = saver_module.Saver(
{
"v0": v0,
"t0": t0.saveable
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(variables.global_variables_initializer())
t0.insert("k11", 33.0).run()
self.assertEqual(111, self.evaluate(v0))
self.assertEqual(b"k11", self.evaluate(t0.keys()))
self.assertEqual(33.0, self.evaluate(t0.values()))
save.restore(sess, save_path + "-00000-of-00002")
self.assertEqual(10, self.evaluate(v0))
self.assertEqual(b"k1", self.evaluate(t0.keys()))
self.assertEqual(30.0, self.evaluate(t0.values()))
# Restore different ops from shard 1 of the saved files.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v1 = variables.VariableV1(222)
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
"v1": v1,
"t1": t1.saveable
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(variables.global_variables_initializer())
t1.insert("k22", 44.0).run()
self.assertEqual(222, self.evaluate(v1))
self.assertEqual(b"k22", self.evaluate(t1.keys()))
self.assertEqual(44.0, self.evaluate(t1.values()))
save.restore(sess, save_path + "-00001-of-00002")
self.assertEqual(20, self.evaluate(v1))
self.assertEqual(b"k2", self.evaluate(t1.keys()))
self.assertEqual(40.0, self.evaluate(t1.values()))
# Now try a restore with the sharded filename.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
v1 = variables.VariableV1(222, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
"v0": v0,
"v1": v1,
"t0": t0.saveable,
"t1": t1.saveable
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(variables.global_variables_initializer())
t0.insert("k11", 33.0).run()
t1.insert("k22", 44.0).run()
self.assertEqual(111, self.evaluate(v0))
self.assertEqual(222, self.evaluate(v1))
self.assertEqual(b"k11", self.evaluate(t0.keys()))
self.assertEqual(33.0, self.evaluate(t0.values()))
self.assertEqual(b"k22", self.evaluate(t1.keys()))
self.assertEqual(44.0, self.evaluate(t1.values()))
save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
if save._write_version is saver_pb2.SaverDef.V1:
save.restore(sess, save_path + "-?????-of-?????")
else:
save.restore(sess, save_path)
self.assertEqual(10, self.evaluate(v0))
self.assertEqual(20, self.evaluate(v1))
self.assertEqual(b"k1", self.evaluate(t0.keys()))
self.assertEqual(30.0, self.evaluate(t0.values()))
self.assertEqual(b"k2", self.evaluate(t1.keys()))
self.assertEqual(40.0, self.evaluate(t1.values()))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(
checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
else:
self.assertEqual(
checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics"))
def testSaverDef(self):
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default(), self.cached_session():
v0 = variables.VariableV1(123, name="v0")
save = saver_module.Saver({"v0": v0}, sharded=True)
sd = save.as_saver_def()
self.assertTrue(sd.sharded)
def _testPartitionedVariables(self, use_resource):
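    """Saves and restores across different partitionings of the same variable."""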
var_full_shape = [10, 3]
    # Checks that the save/restore mechanism works with different slicings.
var_name = "my_var"
saved_dir = self._get_test_dir("partitioned_variables")
saved_path = os.path.join(saved_dir, "ckpt")
call_saver_with_dict = False # updated by test loop below
def _save(partitioner=None):
      # train.Saver is a V1-only API.
with ops_lib.Graph().as_default(), self.session() as sess:
# Calls .eval() to return the ndarray that makes up the full variable.
rnd = random_ops.random_uniform(var_full_shape).eval()
if partitioner:
vs = [
variable_scope.get_variable(
var_name,
shape=var_full_shape,
initializer=rnd,
partitioner=partitioner,
use_resource=use_resource)
]
else:
if use_resource:
vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
else:
vs = [variables.VariableV1(rnd, name=var_name)]
self.evaluate(variables.global_variables_initializer())
if call_saver_with_dict:
saver = saver_module.Saver({var_name: vs[0]})
else:
saver = saver_module.Saver(vs)
actual_path = saver.save(sess, saved_path)
self.assertEqual(saved_path, actual_path)
return rnd
def _restore(partitioner=None):
      # train.Saver is a V1-only API.
with ops_lib.Graph().as_default(), self.session() as sess:
if partitioner:
new_vs = [
variable_scope.get_variable(
var_name,
shape=var_full_shape,
initializer=array_ops.zeros(var_full_shape),
partitioner=partitioner)
]
else:
new_vs = [
variables.VariableV1(
array_ops.zeros(
shape=var_full_shape), # != original contents.
name=var_name)
]
self.evaluate(variables.global_variables_initializer())
if call_saver_with_dict:
saver = saver_module.Saver({
var_name: new_vs[0]
})
else:
saver = saver_module.Saver(new_vs)
saver.restore(sess, saved_path)
if partitioner:
return new_vs[0].as_tensor().eval()
else:
return new_vs[0].eval()
for call_saver_with_dict in {False, True}:
# Save PartitionedVariable and restore into full variable.
saved_full = _save(
partitioner=partitioned_variables.fixed_size_partitioner(
num_shards=2))
restored_full = _restore()
self.assertAllEqual(saved_full, restored_full)
# Restores into the same number of partitions.
restored_full = _restore(
partitioner=partitioned_variables.fixed_size_partitioner(
num_shards=2))
self.assertAllEqual(saved_full, restored_full)
# Restores into a different number of partitions.
restored_full = _restore(
partitioner=partitioned_variables.fixed_size_partitioner(
num_shards=3))
self.assertAllEqual(saved_full, restored_full)
      # Now save a full variable and restore it into a PartitionedVariable.
saved_full = _save()
restored_full = _restore(
partitioner=partitioned_variables.fixed_size_partitioner(
num_shards=3))
self.assertAllEqual(saved_full, restored_full)
def testPartitionedVariable(self):
self._testPartitionedVariables(use_resource=False)
def testPartitionedResourceVariable(self):
self._testPartitionedVariables(use_resource=True)
class SaveRestoreShardedTestV2(SaveRestoreShardedTest):
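  """Reruns the sharded save/restore tests with the V2 checkpoint format."""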
_WRITE_VERSION = saver_pb2.SaverDef.V2
def testIterators(self):
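    # Iterator state is captured through _IteratorSaveable, so restored
    # iterators resume from the elements they had already consumed.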
save_path = os.path.join(self.get_temp_dir(), "sharded_iterators")
# Build a graph with 2 parameter nodes on different devices and save.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
self.assertEqual(0, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next0))
self.assertEqual(0, self.evaluate(get_next1))
val = saver.save(sess, save_path)
self.assertEqual(save_path, val)
data_files = glob.glob(save_path + ".data*")
self.assertEqual(2, len(data_files))
# Restore
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
saver.restore(sess, save_path)
self.assertEqual(2, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next1))
def testIteratorsUnshardedRestore(self):
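    # Saves with a sharded saver, then restores with an unsharded one to
    # check that the checkpoint is readable either way.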
save_path = os.path.join(self.get_temp_dir(), "restore_unsharded_iterators")
# Build a graph with 2 parameter nodes on different devices and save.
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=True)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
self.assertEqual(0, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next0))
self.assertEqual(0, self.evaluate(get_next1))
val = saver.save(sess, save_path)
self.assertEqual(save_path, val)
data_files = glob.glob(save_path + ".data*")
self.assertEqual(2, len(data_files))
# Restore
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
ds0 = dataset_ops.Dataset.range(10)
it0 = dataset_ops.make_initializable_iterator(ds0)
get_next0 = it0.get_next()
saveable0 = iterator_ops._IteratorSaveable(
it0._iterator_resource, name="saveable_it0")
with sess.graph.device("/cpu:1"):
ds1 = dataset_ops.Dataset.range(20)
it1 = dataset_ops.make_initializable_iterator(ds1)
get_next1 = it1.get_next()
saveable1 = iterator_ops._IteratorSaveable(
it1._iterator_resource, name="saveable_it1")
saver = saver_module.Saver({
"it0": saveable0,
"it1": saveable1
},
write_version=self._WRITE_VERSION,
sharded=False)
self.evaluate(it0.initializer)
self.evaluate(it1.initializer)
saver.restore(sess, save_path)
self.assertEqual(2, self.evaluate(get_next0))
self.assertEqual(1, self.evaluate(get_next1))


class MaxToKeepTest(test.TestCase):
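  """Tests that max_to_keep bounds how many checkpoints a Saver retains."""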
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def assertCheckpointState(self, model_checkpoint_path,
all_model_checkpoint_paths, save_dir):
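    """Checks the CheckpointState proto in save_dir against the given paths."""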
checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(checkpoint_state.model_checkpoint_path,
model_checkpoint_path)
self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
all_model_checkpoint_paths)
def testMaxToKeepEager(self):
with context.eager_mode():
save_dir = self._get_test_dir("max_to_keep_eager")
v = variable_scope.variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
self.evaluate(variables.global_variables_initializer())
if not context.executing_eagerly():
self.assertEqual([], save.last_checkpoints)
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
save_dir=save_dir)
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
save_dir=save_dir)
s3 = save.save(None, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
save_dir=save_dir)
# Create a second helper, identical to the first.
save2 = saver_module.Saver({"v": v}, max_to_keep=2)
save2.set_last_checkpoints(save.last_checkpoints)
# Exercise the first helper.
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
save_dir=save_dir)
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
save_dir=save_dir)
s2 = save2.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
# Deleted by the first helper.
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
def testNonSharded(self):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default(), self.cached_session() as sess:
v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
self.evaluate(variables.global_variables_initializer())
self.assertEqual([], save.last_checkpoints)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
save_dir=save_dir)
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
save_dir=save_dir)
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
save_dir=save_dir)
# Create a second helper, identical to the first.
save2 = saver_module.Saver(saver_def=save.as_saver_def())
save2.set_last_checkpoints(save.last_checkpoints)
# Create a third helper, with the same configuration but no knowledge of
# previous checkpoints.
save3 = saver_module.Saver(saver_def=save.as_saver_def())
# Exercise the first helper.
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
save_dir=save_dir)
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
save_dir=save_dir)
# Exercise the second helper.
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save2.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
save_dir=save_dir)
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save2.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save2.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
save_dir=save_dir)
# Exercise the third helper.
# Adding s2 again (but helper is unaware of previous s2)
s2 = save3.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s2], save3.last_checkpoints)
# Created by the first helper.
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
# Even though the file for s1 exists, this saver isn't aware of it, which
# is why it doesn't end up in the checkpoint state.
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s2],
save_dir=save_dir)
# Adding s1 (s3 should not be deleted because helper is unaware of it)
s1 = save3.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save3.last_checkpoints)
self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s3)))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s2)))
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
checkpoint_management.checkpoint_exists(
checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
save_dir=save_dir)
def testSharded(self):
save_dir = self._get_test_dir("max_to_keep_sharded")
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.VariableV1(111, name="v0")
with sess.graph.device("/cpu:1"):
v1 = variables.VariableV1(222, name="v1")
save = saver_module.Saver(
{
"v0": v0,
"v1": v1
}, sharded=True, max_to_keep=2)
self.evaluate(variables.global_variables_initializer())
self.assertEqual([], save.last_checkpoints)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s1)))
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s1)))
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertEqual(0, len(gfile.Glob(s1 + "*")))
self.assertFalse(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s3)))
else:
self.assertEqual(4, len(gfile.Glob(s3 + "*")))
self.assertTrue(
gfile.Exists(checkpoint_management.meta_graph_filename(s3)))
def testNoMaxToKeep(self):
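    # With max_to_keep=None or 0, the saver keeps no checkpoint bookkeeping,
    # so nothing is deleted and last_checkpoints stays empty.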
save_dir = self._get_test_dir("no_max_to_keep")
save_dir2 = self._get_test_dir("max_to_keep_0")
with self.cached_session() as sess:
v = variables.VariableV1(10.0, name="v")
self.evaluate(variables.global_variables_initializer())
# Test max_to_keep being None.
save = saver_module.Saver({"v": v}, max_to_keep=None)
self.assertEqual([], save.last_checkpoints)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
# Test max_to_keep being 0.
save2 = saver_module.Saver({"v": v}, max_to_keep=0)
self.assertEqual([], save2.last_checkpoints)
s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
self.assertEqual([], save2.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
self.assertEqual([], save2.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
def testNoMetaGraph(self):
save_dir = self._get_test_dir("no_meta_graph")
with self.cached_session() as sess:
v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v})
self.evaluate(variables.global_variables_initializer())
s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(
gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
class RecoverLastCheckpointsTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def assertCheckpointState(self, model_checkpoint_path,
all_model_checkpoint_paths, save_dir):
checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(checkpoint_state.model_checkpoint_path,
model_checkpoint_path)
self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
all_model_checkpoint_paths)
def test_recover_last_checkpoints(self):
with context.eager_mode():
save_dir = self._get_test_dir("recover_last_checkpoints")
v = variable_scope.variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=10)
self.evaluate(variables.global_variables_initializer())
self.assertEqual([], save.last_checkpoints)
s1 = save.save(None, os.path.join(save_dir, "ckpt-1"))
s2 = save.save(None, os.path.join(save_dir, "ckpt-2"))
s3 = save.save(None, os.path.join(save_dir, "ckpt-3"))
self.assertEqual([s1, s2, s3], save.last_checkpoints)
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s1, s2, s3],
save_dir=save_dir)
# Create another saver and recover last checkpoints.
save2 = saver_module.Saver({"v": v}, max_to_keep=10)
self.assertEqual([], save2.last_checkpoints)
save2.recover_last_checkpoints([s1, s2, s3])
self.assertEqual([s1, s2, s3], save2.last_checkpoints)
# Remove a checkpoint and check that last checkpoints are
# restored correctly.
for fname in gfile.Glob("{}*".format(s1)):
gfile.Remove(fname)
self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      # Create another saver and recover last checkpoints. The removed
      # checkpoint should be correctly omitted.
save3 = saver_module.Saver({"v": v}, max_to_keep=10)
self.assertEqual([], save3.last_checkpoints)
save3.recover_last_checkpoints([s1, s2, s3])
self.assertEqual([s2, s3], save3.last_checkpoints)
s4 = save3.save(None, os.path.join(save_dir, "ckpt-4"))
self.assertCheckpointState(
model_checkpoint_path=s4,
all_model_checkpoint_paths=[s2, s3, s4],
save_dir=save_dir)
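  def _example_recover(self, v, paths):
    # Illustrative sketch (not collected as a test): recover_last_checkpoints
    # restores only the saver's bookkeeping and keeps just the paths that
    # still exist on disk, as exercised above.
    saver = saver_module.Saver({"v": v}, max_to_keep=10)
    saver.recover_last_checkpoints(paths)
    return saver.last_checkpoints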
class KeepCheckpointEveryNHoursTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
@test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(saver_module, "time")
def testNonSharded(self, mock_time):
save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
with self.cached_session() as sess:
v = variable_scope.variable([10.0], name="v")
# Run the initializer NOW to avoid the 0.5s overhead of the first Run()
# call, which throws the test timing off in fastbuild mode.
self.evaluate(variables.global_variables_initializer())
# Create a saver that will keep the last 2 checkpoints plus one every 0.7
# seconds.
start_time = time.time()
mock_time.time.return_value = start_time
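      # saver_module.time is patched by the decorator above, so setting
      # mock_time.time.return_value advances the saver's notion of wall time
      # deterministically instead of sleeping.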
save = saver_module.Saver(
{
"v": v
}, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600)
self.assertEqual([], save.last_checkpoints)
      # Advance the mocked clock so 1 second has elapsed and s1 will be old
      # enough to keep. (With a real clock, sleep may return early; the mock
      # avoids that.)
mock_time.time.return_value = start_time + 1.0
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
      # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save()
      # would normally delete s1, because max_to_keep is 2. However, s1 is
      # older than 0.7s, so we must keep it.
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
      # s1 should still be here; we are not checking now to reduce time
      # variance in the test.
      # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next
      # call to Save() will delete s2, because max_to_keep is 2 and because
      # we already kept the old s1. s2 is very close in time to s1, so it
      # gets deleted.
s4 = save.save(sess, os.path.join(save_dir, "s4"))
self.assertEqual([s3, s4], save.last_checkpoints)
# Check that s1 is still here, but s2 is gone.
self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(checkpoint_management.checkpoint_exists(s4))
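  # Rough timeline of the test above, on the mocked clock with a 0.7s keep
  # interval: s1 and s2 are both saved at t=1.0s. When s1 is evicted from
  # last_checkpoints it has crossed the 0.7s threshold, so it is kept on
  # disk and the threshold advances; s2 has not crossed the advanced
  # threshold, so it is deleted.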
class SaveRestoreWithVariableNameMap(test.TestCase):
def _testNonReshape(self, variable_op):
save_path = os.path.join(self.get_temp_dir(), "non_reshape")
with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
v0 = variable_op(10.0, name="v0")
v1 = variable_op(20.0, name="v1")
save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
self.evaluate(variables.global_variables_initializer())
# Check that the parameter nodes have been initialized.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
# Save the initialized values in the file at "save_path"
# Use a variable name map to set the saved tensor names
val = save.save(sess, save_path)
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path, val)
      # Verify that the original names are not in the saved file.
save = saver_module.Saver({"v0": v0, "v1": v1})
with self.assertRaisesOpError("not found in checkpoint"):
save.restore(sess, save_path)
    # Verify that the mapped names are present in the saved file and can be
    # restored using the remapped names.
with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(-1.0, name="v0")
v1 = variable_op(-1.0, name="v1")
if not context.executing_eagerly():
with self.assertRaisesOpError("uninitialized"):
self.evaluate(v0)
with self.assertRaisesOpError("uninitialized"):
self.evaluate(v1)
save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
if not context.executing_eagerly():
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
# Add a prefix to the node names in the current graph and Restore using
# remapped names.
with self.session(graph=ops_lib.Graph()) as sess:
v0 = variable_op(-1.0, name="restore_prefix/v0")
v1 = variable_op(-1.0, name="restore_prefix/v1")
if not context.executing_eagerly():
with self.assertRaisesOpError("uninitialized"):
self.evaluate(v0)
with self.assertRaisesOpError("uninitialized"):
self.evaluate(v1)
# Restore the saved values in the parameter nodes.
save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
save.restore(sess, save_path)
# Check that the parameter nodes have been restored.
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
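  # Note: the dict passed to Saver above maps checkpoint names to variables,
  # so {"save_prefix/v0": v0} stores v0 under "save_prefix/v0" regardless of
  # the variable's own name; restoring requires the same mapping.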
@test_util.run_in_graph_and_eager_modes
def testNonReshapeResourceVariable(self):
self._testNonReshape(resource_variable_ops.ResourceVariable)
def testNonReshapeVariable(self):
self._testNonReshape(variables.Variable)
class MetaGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
@test_util.run_v1_only(
"Queue-based input pipelines have been replaced by `tf.data` "
"and not supported in V2.")
def testAddCollectionDef(self):
test_dir = self._get_test_dir("good_collection")
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
# Creates a graph.
v0 = variables.VariableV1(1.0, name="v0")
control_flow_ops.cond(
math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
lambda: math_ops.subtract(v0, 1))
control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),
lambda i: math_ops.add(i, 1), [v0])
var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
count_up_to = var.count_up_to(3)
input_queue = data_flow_ops.FIFOQueue(
30, dtypes.float32, shared_name="collection_queue")
qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to])
variables.global_variables_initializer()
# Creates a saver.
save = saver_module.Saver({"v0": v0})
# Adds a set of collections.
ops_lib.add_to_collection("int_collection", 3)
ops_lib.add_to_collection("float_collection", 3.5)
ops_lib.add_to_collection("string_collection", "hello")
ops_lib.add_to_collection("variable_collection", v0)
# Add QueueRunners.
queue_runner_impl.add_queue_runner(qr)
# Adds user_defined proto in three formats: string, bytes and Any.
queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
ops_lib.add_to_collection("user_defined_string_collection",
str(queue_runner))
ops_lib.add_to_collection("user_defined_bytes_collection",
queue_runner.SerializeToString())
any_buf = Any()
any_buf.Pack(queue_runner)
ops_lib.add_to_collection("user_defined_any_collection", any_buf)
# Generates MetaGraphDef.
meta_graph_def = save.export_meta_graph(filename)
self.assertTrue(meta_graph_def.HasField("saver_def"))
self.assertTrue(meta_graph_def.HasField("graph_def"))
self.assertTrue(meta_graph_def.HasField("meta_info_def"))
self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,
"")
collection_def = meta_graph_def.collection_def
self.assertEqual(len(collection_def), 12)
with ops_lib.Graph().as_default():
# Restores from MetaGraphDef.
new_saver = saver_module.import_meta_graph(filename)
# Generates a new MetaGraphDef.
new_meta_graph_def = new_saver.export_meta_graph()
# It should be the same as the original.
test_util.assert_meta_graph_protos_equal(
self, meta_graph_def, new_meta_graph_def)
def testAddCollectionDefFails(self):
with self.cached_session():
# Creates a graph.
v0 = variables.VariableV1(10.0, name="v0")
# Creates a saver.
save = saver_module.Saver({"v0": v0})
# Generates MetaGraphDef.
meta_graph_def = meta_graph_pb2.MetaGraphDef()
      # Verifies that a collection with an unsupported key will not be added.
ops_lib.add_to_collection(save, 3)
save._add_collection_def(meta_graph_def, save)
self.assertEqual(len(meta_graph_def.collection_def), 0)
      # Verifies that a collection whose item type does not match the
      # expected type will not be added.
ops_lib.add_to_collection("int_collection", 3)
ops_lib.add_to_collection("int_collection", 3.5)
save._add_collection_def(meta_graph_def, "int_collection")
self.assertEqual(len(meta_graph_def.collection_def), 0)
def _testMultiSaverCollectionSave(self, test_dir):
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
saver1 = saver_module.Saver({"v1": v1}, name="saver1")
ops_lib.add_to_collection("savers", saver0)
ops_lib.add_to_collection("savers", saver1)
self.evaluate(variables.global_variables_initializer())
# Saves to different checkpoints.
saver0.save(sess, saver0_ckpt)
saver1.save(sess, saver1_ckpt)
# Generates MetaGraphDef.
meta_graph_def = saver_module.export_meta_graph(filename)
meta_graph_def0 = saver0.export_meta_graph()
meta_graph_def1 = saver1.export_meta_graph()
# Verifies that there is no saver_def in meta_graph_def.
self.assertFalse(meta_graph_def.HasField("saver_def"))
      # Verifies that there is a saver_def in both meta_graph_def0 and
      # meta_graph_def1.
self.assertTrue(meta_graph_def0.HasField("saver_def"))
self.assertTrue(meta_graph_def1.HasField("saver_def"))
# Verifies SAVERS is saved as bytes_list for meta_graph_def.
collection_def = meta_graph_def.collection_def["savers"]
kind = collection_def.WhichOneof("kind")
self.assertEqual(kind, "bytes_list")
# Verifies that there are 2 entries in SAVERS collection.
savers = getattr(collection_def, kind)
self.assertEqual(2, len(savers.value))
# Verifies SAVERS collection is saved as bytes_list for meta_graph_def0.
collection_def = meta_graph_def0.collection_def["savers"]
kind = collection_def.WhichOneof("kind")
self.assertEqual(kind, "bytes_list")
# Verifies that there are 2 entries in SAVERS collection.
savers = getattr(collection_def, kind)
self.assertEqual(2, len(savers.value))
def _testMultiSaverCollectionRestore(self, test_dir):
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Imports from meta_graph.
saver_module.import_meta_graph(filename)
# Retrieves SAVERS collection. Verifies there are 2 entries.
savers = ops_lib.get_collection("savers")
self.assertEqual(2, len(savers))
# Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1.
new_saver0 = savers[0]
new_saver0.restore(sess, saver0_ckpt)
v0 = sess.graph.get_tensor_by_name("v0:0")
v1 = sess.graph.get_tensor_by_name("v1:0")
self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
self.evaluate(v0))
self.assertEqual([3, 2], v0.get_shape())
self.assertEqual([], v1.get_shape())
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
self.evaluate(v1)
# Retrieves saver1. Verifies that new_saver1 can restore v1.
new_saver1 = savers[1]
new_saver1.restore(sess, saver1_ckpt)
v1 = sess.graph.get_tensor_by_name("v1:0")
self.assertEqual(11.0, self.evaluate(v1))
@test_util.run_v1_only(
"Exporting/importing meta graphs is only supported in V1.")
def testMultiSaverCollection(self):
test_dir = self._get_test_dir("saver_collection")
self._testMultiSaverCollectionSave(test_dir)
self._testMultiSaverCollectionRestore(test_dir)
@test_util.run_v1_only(
"Exporting/importing meta graphs is only supported in V1.")
def testClearExtraneousSavers(self):
test_dir = self._get_test_dir("clear_extraneous_savers")
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
saver1 = saver_module.Saver({"v1": v1}, name="saver1")
ops_lib.add_to_collection("savers", saver0)
ops_lib.add_to_collection("savers", saver1)
self.evaluate(variables.global_variables_initializer())
# Saves to different checkpoints.
saver0.save(sess, saver0_ckpt)
saver1.save(sess, saver1_ckpt)
# Generates MetaGraphDef.
meta_graph_def = saver_module.export_meta_graph(filename)
meta_graph_def0 = saver0.export_meta_graph()
meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True)
# Verifies that there is no saver_def in meta_graph_def.
self.assertFalse(meta_graph_def.HasField("saver_def"))
      # Verifies that there is a saver_def in both meta_graph_def0 and
      # meta_graph_def1.
self.assertTrue(meta_graph_def0.HasField("saver_def"))
self.assertTrue(meta_graph_def1.HasField("saver_def"))
# Verifies SAVERS is saved as bytes_list for meta_graph_def.
collection_def = meta_graph_def.collection_def["savers"]
kind = collection_def.WhichOneof("kind")
self.assertEqual(kind, "bytes_list")
# Verifies that there are 2 entries in SAVERS collection.
savers = getattr(collection_def, kind)
self.assertEqual(2, len(savers.value))
# Verifies SAVERS collection is saved as bytes_list for meta_graph_def1.
collection_def = meta_graph_def1.collection_def["savers"]
kind = collection_def.WhichOneof("kind")
self.assertEqual(kind, "bytes_list")
# Verifies that there is 1 entry in SAVERS collection.
savers = getattr(collection_def, kind)
self.assertEqual(1, len(savers.value))
# Verifies that saver0 graph nodes are omitted from the saver1 export
self.assertEqual(33, len(meta_graph_def0.graph_def.node))
self.assertEqual(21, len(meta_graph_def1.graph_def.node))
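      # clear_extraneous_savers=True strips the save/restore subgraph of the
      # other saver from meta_graph_def1 and keeps a single SAVERS entry,
      # which is why it has fewer graph nodes than meta_graph_def0.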
def testBinaryAndTextFormat(self):
test_dir = self._get_test_dir("binary_and_text")
filename = os.path.join(test_dir, "metafile")
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default(), self.session():
# Creates a graph.
variables.VariableV1(10.0, name="v0")
# Exports the graph as binary format.
saver_module.export_meta_graph(filename, as_text=False)
with ops_lib.Graph().as_default(), self.session():
# Imports the binary format graph.
saver = saver_module.import_meta_graph(filename)
self.assertIsNotNone(saver)
# Exports the graph as text format.
saver.export_meta_graph(filename, as_text=True)
with ops_lib.Graph().as_default(), self.session():
# Imports the text format graph.
saver_module.import_meta_graph(filename)
# Writes wrong contents to the file.
graph_io.write_graph(saver.as_saver_def(),
os.path.dirname(filename),
os.path.basename(filename))
with ops_lib.Graph().as_default(), self.session():
# Import should fail.
      with self.assertRaisesWithPredicateMatch(
          IOError, lambda e: "Cannot parse file" in str(e)):
saver_module.import_meta_graph(filename)
# Deletes the file
gfile.Remove(filename)
      with self.assertRaisesWithPredicateMatch(
          IOError, lambda e: "does not exist" in str(e)):
saver_module.import_meta_graph(filename)
@test_util.run_v1_only(
"Exporting/importing meta graphs is only supported in V1.")
def testSliceVariable(self):
test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
v1 = variables.VariableV1([20.0], name="v1")
v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
      # The checkpoint names are different, so saving both will work.
slice_saver = saver_module.Saver({"first": v1, "second": v2})
self.evaluate(variables.global_variables_initializer())
# Exports to meta_graph
meta_graph_def = slice_saver.export_meta_graph(filename)
with ops_lib.Graph().as_default():
# Restores from MetaGraphDef.
new_saver = saver_module.import_meta_graph(filename)
self.assertIsNotNone(new_saver)
# Generates a new MetaGraphDef.
new_meta_graph_def = new_saver.export_meta_graph()
# It should be the same as the original.
test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
new_meta_graph_def)
def _testGraphExtensionSave(self, test_dir):
filename = os.path.join(test_dir, "metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
# Creates an inference graph.
# Hidden 1
images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
with ops_lib.name_scope("hidden1"):
weights = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
    # The use of control_flow_ops.cond here is purely to add test coverage
    # for the save and restore of control flow contexts (which makes no
    # sense here from a machine learning perspective). Typical biases would
    # be a simple Variable without the condition.
biases = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
name="biases")
hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
# Hidden 2
with ops_lib.name_scope("hidden2"):
weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
      # The use of control_flow_ops.while_loop here is purely to add test
      # coverage for the save and restore of control flow contexts (which
      # makes no sense here from a machine learning perspective). Typical
      # biases would be a simple Variable without the loop.
def loop_cond(it, _):
return it < 2
def loop_body(it, biases):
biases += constant_op.constant(0.1, shape=[32])
return it + 1, biases
_, biases = control_flow_ops.while_loop(
loop_cond, loop_body,
[constant_op.constant(0),
variables.VariableV1(array_ops.zeros([32]))])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
init_all_op = variables.global_variables_initializer()
with self.cached_session() as sess:
# Initializes all the variables.
self.evaluate(init_all_op)
# Runs to logit.
self.evaluate(logits)
# Creates a saver.
saver0 = saver_module.Saver()
saver0.save(sess, saver0_ckpt)
# Generates MetaGraphDef.
saver0.export_meta_graph(filename)
def _testGraphExtensionRestore(self, test_dir):
filename = os.path.join(test_dir, "metafile")
train_filename = os.path.join(test_dir, "train_metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Restores from MetaGraphDef.
new_saver = saver_module.import_meta_graph(filename)
# Generates a new MetaGraphDef.
new_saver.export_meta_graph()
# Restores from checkpoint.
new_saver.restore(sess, saver0_ckpt)
# Adds loss and train.
labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels")
batch_size = array_ops.size(labels)
labels = array_ops.expand_dims(labels, 1)
indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1)
concated = array_ops.concat([indices, labels], 1)
onehot_labels = sparse_ops.sparse_to_dense(
concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)
logits = ops_lib.get_collection("logits")[0]
cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
labels=onehot_labels, logits=logits, name="xentropy")
loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")
summary.scalar("loss", loss)
# Creates the gradient descent optimizer with the given learning rate.
optimizer = gradient_descent.GradientDescentOptimizer(0.01)
# Runs train_op.
train_op = optimizer.minimize(loss)
ops_lib.add_to_collection("train_op", train_op)
# Runs train_op.
self.evaluate(train_op)
# Generates MetaGraphDef.
saver_module.export_meta_graph(train_filename)
def _testRestoreFromTrainGraphWithControlContext(self, test_dir):
train_filename = os.path.join(test_dir, "train_metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Restores from MetaGraphDef.
new_saver = saver_module.import_meta_graph(train_filename)
# Restores from checkpoint.
new_saver.restore(sess, saver0_ckpt)
train_op = ops_lib.get_collection("train_op")[0]
self.evaluate(train_op)
def testGraphExtension(self):
test_dir = self._get_test_dir("graph_extension")
    # train.Saver and train.import_meta_graph are V1-only APIs.
with ops_lib.Graph().as_default():
self._testGraphExtensionSave(test_dir)
self._testGraphExtensionRestore(test_dir)
self._testRestoreFromTrainGraphWithControlContext(test_dir)
def _testGradientSerDes(self, graph_fn):
"""Tests that gradients can be computed after exporting and importing.
Builds a graph, exports it, and verifies that it can be imported and the
gradient can be built and run correctly.
Args:
graph_fn: takes a single float Tensor argument as input, outputs a single
Tensor
"""
test_dir = self._get_test_dir("nested_control_flow")
filename = os.path.join(test_dir, "metafile")
saver_ckpt = os.path.join(test_dir, "saver.ckpt")
    # Build the graph from `graph_fn`.
with ops_lib.Graph().as_default():
var = variables.VariableV1(0.0)
var_name = var.name
output = graph_fn(var)
output_name = output.name
init_op = variables.global_variables_initializer()
# Generate a MetaGraphDef containing the while loop.
with session.Session() as sess:
self.evaluate(init_op)
self.evaluate(output)
saver = saver_module.Saver()
saver.save(sess, saver_ckpt)
saver.export_meta_graph(filename)
# Build and run the gradients of the while loop. We use this below to
# verify that the gradients are correct with an imported MetaGraphDef.
grad = gradients_impl.gradients([output], [var])
# Turn off constant folding to avoid breaking testNestedControlFlowSerDes.
# It appears that a missing control dependency in the gradient graph
# causes the fetch node to not be triggered.
no_constfold_config = config_pb2.ConfigProto()
no_constfold_config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
with session.Session(config=no_constfold_config) as sess:
self.evaluate(init_op)
expected_grad_value = self.evaluate(grad)
# Restore the MetaGraphDef into a new Graph.
with ops_lib.Graph().as_default():
with session.Session() as sess:
saver = saver_module.import_meta_graph(filename)
saver.restore(sess, saver_ckpt)
# Make sure we can still build gradients and get the same result.
var = ops_lib.get_default_graph().get_tensor_by_name(var_name)
output = ops_lib.get_default_graph().get_tensor_by_name(output_name)
grad = gradients_impl.gradients([output], [var])
init_op = variables.global_variables_initializer()
with session.Session(config=no_constfold_config) as sess:
self.evaluate(init_op)
actual_grad_value = self.evaluate(grad)
self.assertEqual(expected_grad_value, actual_grad_value)
def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
# Build a while loop with `outer_body_fn`, export it, and verify that it can
# be imported and the gradient can be built and run correctly.
# pylint: disable=g-long-lambda
return self._testGradientSerDes(
lambda x: control_flow_ops.while_loop(
lambda i, y: i < 5, outer_body_fn, [0, x])[1])
# pylint: enable=g-long-lambda
def testNestedWhileLoopsSerDes(self):
# Test two simple nested while loops.
def body(i, x):
_, r = control_flow_ops.while_loop(lambda j, y: j < 3,
lambda j, y: (j + 1, y + x),
[0, 0.0])
return i + 1, x + r
self._testWhileLoopAndGradientSerDes(body)
def testNestedControlFlowSerDes(self):
# Test while loop in a cond in a while loop.
# pylint: disable=g-long-lambda
def body(i, x):
cond_result = control_flow_ops.cond(
i > 0,
lambda: control_flow_ops.while_loop(
lambda j, y: j < 3,
lambda j, y: (j + 1, y + x),
[0, 0.0])[1],
lambda: x)
return i + 1, cond_result
# pylint: enable=g-long-lambda
self._testWhileLoopAndGradientSerDes(body)
def testNestedCondsSerDes(self):
# Test conds in a cond.
# pylint: disable=g-long-lambda
self._testGradientSerDes(lambda x: control_flow_ops.cond(
x > 0,
lambda: control_flow_ops.cond(x > 3,
lambda: array_ops.identity(x),
lambda: math_ops.multiply(x, 2.0)),
lambda: control_flow_ops.cond(x < -3,
lambda: constant_op.constant(1.0),
lambda: math_ops.multiply(x, -1.0))))
# pylint: enable=g-long-lambda
@test_util.run_v1_only("This exercises Tensor.op which is meaningless in V2.")
def testStrippedOpListDef(self):
with self.cached_session():
# Creates a graph.
v0 = variables.VariableV1(0.0)
var = variables.VariableV1(10.0)
math_ops.add(v0, var)
@function.Defun(dtypes.float32)
def minus_one(x):
return x - 1
minus_one(array_ops.identity(v0))
save = saver_module.Saver({"v0": v0})
variables.global_variables_initializer()
# Generates MetaGraphDef.
meta_graph_def = save.export_meta_graph()
ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(ops, [
"Add", "Assign", "Const", "Identity", "NoOp",
"PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub",
"VariableV2"
])
else:
self.assertEqual(ops, [
"Add", "Assign", "Const", "Identity", "NoOp",
"PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2"
])
# Test calling stripped_op_list_for_graph directly
op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)
self.assertEqual(ops, [o.name for o in op_list.op])
for o in op_list.op:
self.assertEqual(o.summary, "")
self.assertEqual(o.description, "")
def testStripDefaultValuedAttrs(self):
"""Verifies that default valued attrs are stripped, unless disabled."""
# With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must be removed.
    # train.Saver and train.export_meta_graph are V1-only APIs.
with ops_lib.Graph().as_default(), self.cached_session():
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
variables.global_variables_initializer()
meta_graph_def = save.export_meta_graph(strip_default_attrs=True)
node_def = test_util.get_node_def_from_graph("complex",
meta_graph_def.graph_def)
self.assertNotIn("T", node_def.attr)
self.assertNotIn("Tout", node_def.attr)
# With strip_default_attrs disabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must *not* be removed, even if they map
# to their defaults.
with ops_lib.Graph().as_default(), self.session():
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
variables.global_variables_initializer()
meta_graph_def = save.export_meta_graph(strip_default_attrs=False)
node_def = test_util.get_node_def_from_graph("complex",
meta_graph_def.graph_def)
self.assertIn("T", node_def.attr)
self.assertIn("Tout", node_def.attr)
def testImportIntoNamescope(self):
# Test that we can import a meta graph into a namescope.
test_dir = self._get_test_dir("import_into_namescope")
filename = os.path.join(test_dir, "ckpt")
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default():
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(
math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(
labels=label, logits=logit, name="cost")
adam.AdamOptimizer().minimize(cost, name="optimize")
saver = saver_module.Saver()
self.evaluate(variables.global_variables_initializer())
saver.save(sess, filename)
graph = ops_lib.Graph()
with session.Session(graph=graph) as sess:
new_saver = saver_module.import_meta_graph(
filename + ".meta", graph=graph, import_scope="new_model")
new_saver.restore(sess, filename)
sess.run(["new_model/optimize"], {
"new_model/image:0": np.random.random([1, 784]),
"new_model/label:0": np.random.randint(
10, size=[1, 10])
})
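  # Note on the feeds above: importing with import_scope="new_model"
  # prefixes every imported node name, so the original "image:0" and
  # "label:0" tensors must be addressed as "new_model/image:0" and
  # "new_model/label:0".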
def testImportIntoNamescopeWithoutVariables(self):
# Save a simple graph that contains no variables into a checkpoint.
test_dir = self._get_test_dir("no_vars_graph")
filename = os.path.join(test_dir, "ckpt")
graph_1 = ops_lib.Graph()
with session.Session(graph=graph_1) as sess:
constant_op.constant([1, 2, 3], name="x")
constant_op.constant([1, 2, 3], name="y")
saver = saver_module.Saver(allow_empty=True)
saver.save(sess, filename)
# Create a fresh graph.
graph_2 = ops_lib.Graph()
with session.Session(graph=graph_2) as sess:
# Restore the above checkpoint under scope "subgraph_1".
new_saver_1 = saver_module.import_meta_graph(
filename + ".meta", graph=graph_2, import_scope="subgraph_1")
# There are no variables to restore, so import_meta_graph should not
# return a Saver.
self.assertIsNone(new_saver_1)
# Create a variable in graph_2 under scope "my_scope".
variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var")
self.evaluate(variables.global_variables_initializer())
# Restore the checkpoint into a different scope "subgraph_2".
new_saver_2 = saver_module.import_meta_graph(
filename + ".meta", graph=graph_2, import_scope="subgraph_2")
# Because the variable does not live in scope "subgraph_2",
# import_meta_graph should not attempt to restore the variable. So,
# import_meta_graph still won't return a Saver instance.
self.assertIsNone(new_saver_2)
# However, if we restore the checkpoint under scope "my_scope",
# import_meta_graph will detect the variable and return a Saver for
# restoring it. This should happen even when the variable does not
# originate from graph_1.
new_saver_3 = saver_module.import_meta_graph(
filename + ".meta", graph=graph_2, import_scope="my_scope")
self.assertIsInstance(new_saver_3, saver_module.Saver)
def testImportIntoImplicitNamescope(self):
# Test that we can import a meta graph into an implicit namescope.
test_dir = self._get_test_dir("import_into_namescope")
filename = os.path.join(test_dir, "ckpt")
    # train.Saver is a V1-only API.
with ops_lib.Graph().as_default():
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(
math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(
labels=label, logits=logit, name="cost")
adam.AdamOptimizer().minimize(cost, name="optimize")
saver = saver_module.Saver()
self.evaluate(variables.global_variables_initializer())
saver.save(sess, filename)
graph = ops_lib.Graph()
with session.Session(graph=graph) as sess:
with ops_lib.name_scope("new_model"):
new_saver = saver_module.import_meta_graph(
filename + ".meta", graph=graph)
new_saver.restore(sess, filename)
sess.run(["new_model/optimize"], {
"new_model/image:0": np.random.random([1, 784]),
"new_model/label:0": np.random.randint(
10, size=[1, 10])
})
def testClearDevicesOnImport(self):
    # Test that we can import a graph without its devices and run it
    # successfully.
with ops_lib.Graph().as_default():
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
logits=logit)
adam.AdamOptimizer().minimize(cost, name="optimize")
meta_graph_def = saver_module.export_meta_graph()
with session.Session(graph=ops_lib.Graph()) as sess:
saver_module.import_meta_graph(
meta_graph_def, clear_devices=False, import_scope="new_model")
# Device refers to GPU, which is not available here.
with self.assertRaises(errors_impl.InvalidArgumentError):
self.evaluate(variables.global_variables_initializer())
with session.Session(graph=ops_lib.Graph()) as sess:
saver_module.import_meta_graph(
meta_graph_def, clear_devices=True, import_scope="new_model")
self.evaluate(variables.global_variables_initializer())
sess.run(["new_model/optimize"], {
"new_model/image:0": np.random.random([1, 784]),
"new_model/label:0": np.random.randint(
10, size=[1, 10])
})
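  def _example_devices(self, meta_graph_def):
    # Illustrative helper (not collected as a test): clear_devices=True
    # empties every NodeDef's `device` field in the exported graph, so
    # placement is re-decided at import time. This returns the recorded
    # placements for inspection.
    return {node.name: node.device
            for node in meta_graph_def.graph_def.node}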
def testClearDevicesOnExport(self):
    # Test that we can export a graph without its devices and run it
    # successfully.
with ops_lib.Graph().as_default():
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
logits=logit)
adam.AdamOptimizer().minimize(cost, name="optimize")
meta_graph_def = saver_module.export_meta_graph(clear_devices=True)
graph_io.write_graph(meta_graph_def, self.get_temp_dir(),
"meta_graph.pbtxt")
with session.Session(graph=ops_lib.Graph()) as sess:
saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
self.evaluate(variables.global_variables_initializer())
sess.run(["new_model/optimize"], {
"new_model/image:0": np.random.random([1, 784]),
"new_model/label:0": np.random.randint(
10, size=[1, 10])
})
def testPreserveDatasetAndFunctions(self):
with ops_lib.Graph().as_default() as g:
dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x)
iterator = dataset_ops.make_one_shot_iterator(dataset)
next_element = iterator.get_next()
_ = array_ops.identity(next_element, name="output")
# Generate three MetaGraphDef protos using different code paths.
meta_graph_def_simple = saver_module.export_meta_graph()
meta_graph_def_devices_cleared = saver_module.export_meta_graph(
clear_devices=True)
meta_graph_def_from_graph_def = saver_module.export_meta_graph(
clear_devices=True, graph_def=g.as_graph_def())
for meta_graph_def in [meta_graph_def_simple,
meta_graph_def_devices_cleared,
meta_graph_def_from_graph_def]:
with session.Session(graph=ops_lib.Graph()) as sess:
saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
self.evaluate(variables.global_variables_initializer())
for i in range(10):
self.assertEqual(i * i, sess.run("new_model/output:0"))
with self.assertRaises(errors.OutOfRangeError):
sess.run("new_model/output:0")
class CheckpointReaderTest(test.TestCase):
_WRITE_VERSION = saver_pb2.SaverDef.V1
def testDebugString(self):
# Builds a graph.
v0 = variables.VariableV1(
[[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
v1 = variables.VariableV1(
[[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")
init_all_op = variables.global_variables_initializer()
save = saver_module.Saver(
{
"v0": v0,
"v1": v1
}, write_version=self._WRITE_VERSION)
save_path = os.path.join(self.get_temp_dir(),
"ckpt_for_debug_string" + str(self._WRITE_VERSION))
with self.cached_session() as sess:
self.evaluate(init_all_op)
# Saves a checkpoint.
save.save(sess, save_path)
# Creates a reader.
reader = py_checkpoint_reader.NewCheckpointReader(save_path)
# Verifies that the tensors exist.
self.assertTrue(reader.has_tensor("v0"))
self.assertTrue(reader.has_tensor("v1"))
debug_string = reader.debug_string()
# Verifies that debug string contains the right strings.
self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string)
self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string)
# Verifies get_variable_to_shape_map() returns the correct information.
var_map = reader.get_variable_to_shape_map()
self.assertEqual([2, 3], var_map["v0"])
self.assertEqual([3, 2, 1], var_map["v1"])
# Verifies get_tensor() returns the tensor value.
v0_tensor = reader.get_tensor("v0")
v1_tensor = reader.get_tensor("v1")
self.assertAllEqual(v0, v0_tensor)
self.assertAllEqual(v1, v1_tensor)
# Verifies get_tensor() fails for non-existent tensors.
with self.assertRaisesRegex(errors.NotFoundError,
"v3 not found in checkpoint"):
reader.get_tensor("v3")
def testNonexistentPath(self):
with self.assertRaisesRegex(errors.NotFoundError,
"Unsuccessful TensorSliceReader"):
py_checkpoint_reader.NewCheckpointReader("non-existent")
class CheckpointReaderForV2Test(CheckpointReaderTest):
_WRITE_VERSION = saver_pb2.SaverDef.V2
class WriteGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def testWriteGraph(self):
test_dir = self._get_test_dir("write_graph_dir")
variables.VariableV1(
[[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph(),
os.path.join(test_dir, "l1"), "graph.pbtxt")
truth = os.path.join(test_dir, "l1", "graph.pbtxt")
self.assertEqual(path, truth)
self.assertTrue(os.path.exists(path))
def testRecursiveCreate(self):
test_dir = self._get_test_dir("deep_dir")
variables.VariableV1(
[[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
os.path.join(test_dir, "l1", "l2", "l3"),
"graph.pbtxt")
truth = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt")
self.assertEqual(path, truth)
self.assertTrue(os.path.exists(path))
class ScopedGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
test_dir = os.path.join(self.get_temp_dir(), dirname)
gfile.MakeDirs(test_dir)
return test_dir
def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
graph = ops_lib.Graph()
with graph.as_default():
# Creates an inference graph.
# Hidden 1
images = constant_op.constant(
1.2, dtypes.float32, shape=[100, 28], name="images")
with ops_lib.name_scope("hidden1"):
weights1 = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
        # The use of control_flow_ops.cond here is purely to add test
        # coverage for the save and restore of control flow contexts (which
        # makes no sense here from a machine learning perspective). Typical
        # biases would be a simple Variable without the condition.
biases1 = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
name="biases")
hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)
# Hidden 2
with ops_lib.name_scope("hidden2"):
weights2 = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
        # The use of control_flow_ops.while_loop here is purely to add test
        # coverage for the save and restore of control flow contexts (which
        # makes no sense here from a machine learning perspective). Typical
        # biases would be a simple Variable without the loop.
def loop_cond(it, _):
return it < 2
def loop_body(it, biases2):
biases2 += constant_op.constant(0.1, shape=[32])
return it + 1, biases2
_, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [
constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
# Linear
with ops_lib.name_scope("softmax_linear"):
weights3 = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights3) + biases3
ops_lib.add_to_collection("logits", logits)
# Adds user_defined proto in three formats: string, bytes and Any.
# Any proto should just pass through.
queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
ops_lib.add_to_collection("user_defined_string_collection",
str(queue_runner))
ops_lib.add_to_collection("user_defined_bytes_collection",
queue_runner.SerializeToString())
any_buf = Any()
any_buf.Pack(queue_runner)
ops_lib.add_to_collection("user_defined_any_collection", any_buf)
_, var_list = meta_graph.export_scoped_meta_graph(
filename=os.path.join(test_dir, exported_filename),
graph=ops_lib.get_default_graph(),
export_scope="hidden1")
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
with graph.as_default(), self.session() as sess:
self.evaluate(variables.global_variables_initializer())
saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)
def _testScopedRestore(self, test_dir, exported_filename,
new_exported_filename, ckpt_filename):
graph = ops_lib.Graph()
# Create all the missing inputs.
with graph.as_default():
new_image = constant_op.constant(
1.2, dtypes.float32, shape=[100, 28], name="images")
var_list = meta_graph.import_scoped_meta_graph(
os.path.join(test_dir, exported_filename),
graph=graph,
input_map={"$unbound_inputs_images": new_image},
import_scope="new_hidden1")
self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
hidden1 = graph.as_graph_element("new_hidden1/Relu:0")
weights1 = graph.as_graph_element("new_hidden1/weights:0")
biases1 = graph.as_graph_element("new_hidden1/biases:0")
with graph.as_default():
# Hidden 2
with ops_lib.name_scope("hidden2"):
weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
        # The use of control_flow_ops.while_loop here is purely to add test
        # coverage for the save and restore of control flow contexts (which
        # makes no sense here from a machine learning perspective). Typical
        # biases would be a simple Variable without the loop.
def loop_cond(it, _):
return it < 2
def loop_body(it, biases):
biases += constant_op.constant(0.1, shape=[32])
return it + 1, biases
_, biases = control_flow_ops.while_loop(loop_cond, loop_body, [
constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
# The rest of the variables.
      rest_variables = list(
          set(variables.global_variables()) - set(var_list.values()))
init_rest_op = variables.variables_initializer(rest_variables)
with graph.as_default(), self.session() as sess:
saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
saver.restore(sess, os.path.join(test_dir, ckpt_filename))
# Verify that we have restored weights1 and biases1.
self.evaluate([weights1, biases1])
# Initialize the rest of the variables and run logits.
self.evaluate(init_rest_op)
self.evaluate(logits)
# Verifies that we can save the subgraph under "hidden1" and restore it
# into "new_hidden1" in the new graph.
def testScopedSaveAndRestore(self):
test_dir = self._get_test_dir("scoped_export_import")
ckpt_filename = "ckpt"
self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",
"exported_new_hidden1.pbtxt", ckpt_filename)
  # Verifies that we can copy the subgraph under "hidden1" to a different
  # name scope in the same graph or in a different graph.
def testCopyScopedGraph(self):
test_dir = self._get_test_dir("scoped_copy")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
graph1 = ops_lib.Graph()
with graph1.as_default():
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
with graph1.as_default(), self.session(graph=graph1) as sess:
self.evaluate(variables.global_variables_initializer())
_, var_list_1 = meta_graph.export_scoped_meta_graph(
export_scope="hidden1")
saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
saver.save(sess, saver0_ckpt, write_state=False)
expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))
# Verifies copy to the same graph with the same name fails.
with graph1.as_default():
with self.assertRaisesWithPredicateMatch(
ValueError, lambda e: "need to be different" in str(e)):
meta_graph.copy_scoped_meta_graph(
from_scope="hidden1", to_scope="hidden1")
# Verifies copy to the same graph.
with graph1.as_default():
var_list_2 = meta_graph.copy_scoped_meta_graph(
from_scope="hidden1", to_scope="hidden2")
with graph1.as_default(), self.session(graph=graph1) as sess:
saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
saver1.restore(sess, saver0_ckpt)
saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1)
saver2.restore(sess, saver0_ckpt)
self.assertAllClose(expected, sess.run("hidden1/relu:0"))
self.assertAllClose(expected, sess.run("hidden2/relu:0"))
# Verifies copy to different graph.
graph2 = ops_lib.Graph()
with graph2.as_default():
new_var_list_1 = meta_graph.copy_scoped_meta_graph(
from_scope="hidden1",
to_scope="new_hidden1",
from_graph=graph1,
to_graph=graph2)
with self.session() as sess:
saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
saver3.restore(sess, saver0_ckpt)
self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
def testExportGraphDefWithScope(self):
test_dir = self._get_test_dir("export_graph_def")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
graph1 = ops_lib.Graph()
with graph1.as_default():
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
with self.session(graph=graph1) as sess:
self.evaluate(variables.global_variables_initializer())
_, var_list_1 = meta_graph.export_scoped_meta_graph(
graph_def=graph1.as_graph_def(), export_scope="hidden1")
saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
saver.save(sess, saver0_ckpt, write_state=False)
expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))
# Verifies that we can run successfully after restoring.
graph2 = ops_lib.Graph()
with graph2.as_default():
new_var_list_1 = meta_graph.copy_scoped_meta_graph(
from_scope="hidden1",
to_scope="new_hidden1",
from_graph=graph1,
to_graph=graph2)
with self.session(graph=graph2) as sess:
saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
saver3.restore(sess, saver0_ckpt)
self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))
def testSerializeSaverWithScope(self):
test_dir = self._get_test_dir("export_graph_def")
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
saver2_ckpt = os.path.join(test_dir, "saver2.ckpt")
graph = ops_lib.Graph()
with graph.as_default():
with ops_lib.name_scope("hidden1"):
variable1 = variables.VariableV1([1.0], name="variable1")
saver1 = saver_module.Saver(var_list=[variable1])
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)
with ops_lib.name_scope("hidden2"):
variable2 = variables.VariableV1([2.0], name="variable2")
saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)
with self.session(graph=graph) as sess:
self.evaluate(variables.global_variables_initializer())
saver1.save(sess, saver1_ckpt, write_state=False)
saver2.save(sess, saver2_ckpt, write_state=False)
graph1 = ops_lib.Graph()
with graph1.as_default():
var_dict1 = meta_graph.copy_scoped_meta_graph(
from_scope="hidden1",
to_scope="new_hidden1",
from_graph=graph,
to_graph=graph1)
self.assertEqual(1, len(var_dict1))
saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)
self.assertEqual(1, len(saver_list1))
with self.session(graph=graph1) as sess:
saver_list1[0].restore(sess, saver1_ckpt)
self.assertEqual(1.0, self.evaluate(var_dict1["variable1:0"]))
graph2 = ops_lib.Graph()
with graph2.as_default():
var_dict2 = meta_graph.copy_scoped_meta_graph(
from_scope="hidden2",
to_scope="new_hidden2",
from_graph=graph,
to_graph=graph2)
self.assertEqual(1, len(var_dict2))
saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)
self.assertEqual(1, len(saver_list2))
with self.session(graph=graph2) as sess:
saver_list2[0].restore(sess, saver2_ckpt)
self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"]))
class _OwnsAVariableSimple(trackable_base.Trackable):
"""A Trackable object which can be saved using a tf.train.Saver."""
def __init__(self):
self.non_dep_variable = variable_scope.get_variable(
name="non_dep_variable", initializer=6., use_resource=True)
def _gather_saveables_for_checkpoint(self):
return {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}
# The Saver sorts by name before parsing, so we need a name property.
@property
def name(self):
return self.non_dep_variable.name
class _MirroringSaveable(
saver_module.BaseSaverBuilder.ResourceVariableSaveable):
def __init__(self, primary_variable, mirrored_variable, name):
self._primary_variable = primary_variable
self._mirrored_variable = mirrored_variable
super(_MirroringSaveable, self).__init__(
self._primary_variable, "", name)
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into both variables."""
tensor, = restored_tensors
return control_flow_ops.group(
self._primary_variable.assign(tensor),
self._mirrored_variable.assign(tensor))
class _OwnsMirroredVariables(trackable_base.Trackable):
"""A Trackable object which returns a more complex SaveableObject."""
def __init__(self):
self.non_dep_variable = variable_scope.get_variable(
name="non_dep_variable", initializer=6., use_resource=True)
self.mirrored = variable_scope.get_variable(
name="mirrored", initializer=15., use_resource=True)
def _gather_saveables_for_checkpoint(self):
def _saveable_factory(name=self.non_dep_variable.name):
return _MirroringSaveable(
primary_variable=self.non_dep_variable,
mirrored_variable=self.mirrored,
name=name)
return {trackable_base.VARIABLE_VALUE_KEY: _saveable_factory}
# The Saver sorts by name before parsing, so we need a name property.
@property
def name(self):
return self.non_dep_variable.name
class TrackableCompatibilityTests(test.TestCase):
# TODO(allenl): Track down python3 reference cycles in these tests.
@test_util.run_in_graph_and_eager_modes
def testNotSaveableButIsTrackable(self):
v = _OwnsAVariableSimple()
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
for saver in (saver_module.Saver(var_list=[v]),
saver_module.Saver(var_list={"v": v})):
with self.cached_session() as sess:
self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
saver.restore(sess, save_path)
self.assertEqual(42., self.evaluate(v.non_dep_variable))
@test_util.run_in_graph_and_eager_modes
def testMoreComplexSaveableReturned(self):
v = _OwnsMirroredVariables()
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
self.evaluate(v.non_dep_variable.assign(42.))
for saver in (saver_module.Saver(var_list=[v]),
saver_module.Saver(var_list={"v": v})):
with self.cached_session() as sess:
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
self.evaluate(v.mirrored.assign(44.))
saver.restore(sess, save_path)
self.assertEqual(42., self.evaluate(v.non_dep_variable))
self.assertEqual(42., self.evaluate(v.mirrored))
def testSingleTensorEvaluation(self):
class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject):
def __init__(self, name):
self.eval_count = 0
def _tensor():
self.eval_count += 1
return constant_op.constant([1.])
dummy_op = constant_op.constant([2.])
super(_CountingSaveable, self).__init__(
dummy_op,
[saver_module.BaseSaverBuilder.SaveSpec(
_tensor, "", name, dtype=dummy_op.dtype,
device=dummy_op.device)],
name)
      def restore(self, restored_tensors, restored_shapes):
        """Restore is a no-op for this counting saveable."""
        pass
with context.eager_mode():
v = _CountingSaveable("foo")
saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
with self.cached_session() as sess:
save_path = saver.save(sess, prefix)
self.assertEqual(1, v.eval_count)
saver.restore(sess, save_path)
self.assertEqual(1, v.eval_count)
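  # Note: SaveSpec accepts a callable in place of a concrete tensor; it is
  # invoked once per save() and not on restore(), which is what the
  # eval_count assertions above rely on.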
def testVariableNotFoundErrorRaised(self):
# Restore does some tricky exception handling to figure out if it should
# load an object-based checkpoint. Tests that the exception handling isn't
# too broad.
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
a = resource_variable_ops.ResourceVariable(1., name="a")
b = resource_variable_ops.ResourceVariable(1., name="b")
a_saver = saver_module.Saver([a])
b_saver = saver_module.Saver([b])
with self.cached_session() as sess:
self.evaluate(a.initializer)
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with self.assertRaisesRegex(errors.NotFoundError,
"Key b not found in checkpoint"):
b_saver.restore(sess=sess, save_path=save_path)
with self.assertRaises(errors.NotFoundError) as cs:
b_saver.restore(sess=sess, save_path=save_path)
# Make sure we don't have a confusing "During handling of the above
# exception" block in Python 3.
self.assertNotIn("NewCheckpointReader", cs.exception.message)
@test_util.run_v1_only("train.Saver is V1 only API.")
def testGraphChangedForRestoreErrorRaised(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with ops_lib.Graph().as_default() as g:
a = variables.VariableV1(1., name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
self.evaluate(a.initializer)
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with ops_lib.Graph().as_default() as g:
a = variables.VariableV1([1.], name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"a mismatch between the current graph and the graph"):
a_saver.restore(sess=sess, save_path=save_path)
if __name__ == "__main__":
test.main()