Fork keras related tracking test to keras/tests
PiperOrigin-RevId: 316482123 Change-Id: I20645bbfdd926e2c83136ee27c6ef9325cb1f438
This commit is contained in:
parent
3854251adb
commit
16ac7c04d4
@ -370,6 +370,71 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tracking_test",
|
||||
srcs = ["tracking_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras/engine",
|
||||
"//tensorflow/python/keras/layers:core",
|
||||
"//tensorflow/python/keras/layers:normalization",
|
||||
"//tensorflow/python/module",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"//tensorflow/python/training/tracking:data_structures",
|
||||
"//tensorflow/python/training/tracking:util",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tracking_util_test",
|
||||
srcs = ["tracking_util_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = ["notsan"], # b/74395663
|
||||
deps = [
|
||||
"//tensorflow/compiler/tests:xla_test",
|
||||
"//tensorflow/python:checkpoint_management",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:saver",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:template",
|
||||
"//tensorflow/python:training_util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/keras/engine",
|
||||
"//tensorflow/python/keras/layers:core",
|
||||
"//tensorflow/python/keras/optimizer_v2",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"//tensorflow/python/training/tracking:graph_view",
|
||||
"//tensorflow/python/training/tracking:util",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "get_config_samples",
|
||||
srcs = ["get_config_samples.py"],
|
||||
|
610
tensorflow/python/keras/tests/tracking_test.py
Normal file
610
tensorflow/python/keras/tests/tracking_test.py
Normal file
@ -0,0 +1,610 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy
|
||||
import six
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.layers import normalization
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util
|
||||
|
||||
|
||||
class HasList(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasList, self).__init__()
|
||||
self.layer_list = data_structures.List([core.Dense(3)])
|
||||
self.layer_list.append(core.Dense(4))
|
||||
self.layer_list.extend(
|
||||
[core.Dense(5),
|
||||
core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
|
||||
self.layer_list += [
|
||||
core.Dense(7, bias_regularizer=math_ops.reduce_sum),
|
||||
core.Dense(8)
|
||||
]
|
||||
self.layer_list += (
|
||||
data_structures.List([core.Dense(9)]) + data_structures.List(
|
||||
[core.Dense(10)]))
|
||||
self.layer_list.extend(
|
||||
data_structures.List(
|
||||
list([core.Dense(11)]) + [core.Dense(12)]))
|
||||
self.layers_with_updates = data_structures.List(
|
||||
(normalization.BatchNormalization(),))
|
||||
|
||||
def call(self, x):
|
||||
aggregation = 0.
|
||||
for l in self.layer_list:
|
||||
x = l(x)
|
||||
aggregation += math_ops.reduce_sum(x)
|
||||
bn, = self.layers_with_updates
|
||||
return bn(x) / aggregation
|
||||
|
||||
|
||||
class ListTests(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testTracking(self):
|
||||
model = HasList()
|
||||
output = model(array_ops.ones([32, 2]))
|
||||
self.assertAllEqual([32, 12], output.shape)
|
||||
self.assertEqual(11, len(model.layers))
|
||||
self.assertEqual(10, len(model.layer_list.layers))
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
model.layers,
|
||||
model.layer_list.layers + model.layers_with_updates)
|
||||
for index in range(10):
|
||||
self.assertEqual(3 + index, model.layer_list.layers[index].units)
|
||||
self.assertEqual(2, len(model._checkpoint_dependencies))
|
||||
self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
|
||||
self.assertIs(model.layers_with_updates,
|
||||
model._checkpoint_dependencies[1].ref)
|
||||
self.assertEqual(
|
||||
10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies))
|
||||
self.evaluate([v.initializer for v in model.variables])
|
||||
self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
|
||||
model.load_weights(save_path)
|
||||
self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
|
||||
self.evaluate(model.variables[0]))
|
||||
v = variables.Variable(1.)
|
||||
model.var_list = [v]
|
||||
self.assertIn(v, model.variables)
|
||||
self.assertIn(v, model.trainable_variables)
|
||||
self.assertNotIn(v, model.non_trainable_variables)
|
||||
self.assertIn(model.layer_list[0].trainable_weights[0],
|
||||
model.trainable_weights)
|
||||
|
||||
def testSubModelTracking(self):
|
||||
model = training.Model()
|
||||
model.v = variables.Variable(1.)
|
||||
self.assertIn(model.v, model.trainable_weights)
|
||||
model2 = training.Model()
|
||||
model2.m = [model]
|
||||
self.assertIn(model.v, model2.trainable_weights)
|
||||
|
||||
def testSubSequentialTracking(self):
|
||||
|
||||
class _Subclassed(training.Model):
|
||||
|
||||
def __init__(self, wrapped):
|
||||
super(_Subclassed, self).__init__()
|
||||
self._wrapped = wrapped
|
||||
|
||||
def call(self, x):
|
||||
return self._wrapped(x)
|
||||
|
||||
model = sequential.Sequential()
|
||||
layer = core.Dense(1)
|
||||
model.add(layer)
|
||||
model2 = _Subclassed(model)
|
||||
model2(array_ops.ones([1, 2]))
|
||||
model2.m = [model]
|
||||
self.assertIn(layer.kernel, model2.trainable_weights)
|
||||
|
||||
def testLayerTrackedThroughSequential(self):
|
||||
class AttrDict(dict):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
def ffnet(layer_sizes, name):
|
||||
ff = sequential.Sequential(name=name)
|
||||
for i, width in enumerate(layer_sizes):
|
||||
ff.add(core.Dense(
|
||||
width,
|
||||
activation=("relu" if i < len(layer_sizes)-1 else None)))
|
||||
return ff
|
||||
|
||||
class MyModel2(training.Model):
|
||||
|
||||
def __init__(self, config, name="my_model_2"):
|
||||
super(MyModel2, self).__init__(name=name)
|
||||
self._num_tokens = config.num_tokens
|
||||
|
||||
# list of sub-models
|
||||
self._ffnet = [ffnet(config.module_layers + (self._num_tokens,), "ff")]
|
||||
|
||||
def null_input(self):
|
||||
return array_ops.zeros([1, self._num_tokens], dtype=dtypes.float32)
|
||||
|
||||
def call(self, input_, module_index=None):
|
||||
return self._ffnet[0](input_)
|
||||
|
||||
m2 = MyModel2(AttrDict(
|
||||
num_tokens=5,
|
||||
module_layers=(50, 30)))
|
||||
|
||||
# Construct
|
||||
m2(m2.null_input())
|
||||
self.assertLen(m2.trainable_variables, 6)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testUpdatesForwarded(self):
|
||||
with context.graph_mode():
|
||||
model = HasList()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertGreater(len(model.layers_with_updates[0].updates), 0)
|
||||
self.assertEqual(set(model.layers_with_updates[0].updates),
|
||||
set(model.updates))
|
||||
|
||||
with context.eager_mode():
|
||||
model = HasList()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertEqual(0, len(model.updates))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testLossesForwarded(self):
|
||||
model = HasList()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertEqual(2, len(model.losses))
|
||||
|
||||
def testModelContainersCompareEqual(self):
|
||||
class HasEqualContainers(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasEqualContainers, self).__init__()
|
||||
self.l1 = []
|
||||
self.l2 = []
|
||||
|
||||
model = HasEqualContainers()
|
||||
first_layer = HasEqualContainers()
|
||||
model.l1.append(first_layer)
|
||||
second_layer = HasEqualContainers()
|
||||
model.l2.append(second_layer)
|
||||
self.assertEqual([first_layer, second_layer], model.layers)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTensorConversion(self):
|
||||
|
||||
class ListToTensor(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(ListToTensor, self).__init__()
|
||||
self.l = [1., 2., 3.]
|
||||
|
||||
self.assertAllEqual(
|
||||
[1., 2., 3.],
|
||||
self.evaluate(constant_op.constant(ListToTensor().l)))
|
||||
|
||||
self.assertAllEqual(
|
||||
[1., 2., 3.],
|
||||
self.evaluate(array_ops.pack(ListToTensor().l)))
|
||||
|
||||
|
||||
class ListWrapperTest(test.TestCase):
|
||||
|
||||
def testLayerCollectionWithExternalMutation(self):
|
||||
l = []
|
||||
l_wrapper = data_structures.ListWrapper(l)
|
||||
layer = core.Dense(1)
|
||||
l.append(layer)
|
||||
self.assertEqual([layer], l_wrapper.layers)
|
||||
|
||||
|
||||
class HasMapping(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasMapping, self).__init__()
|
||||
self.layer_dict = data_structures.Mapping(output=core.Dense(7))
|
||||
self.layer_dict["norm"] = data_structures.List()
|
||||
self.layer_dict["dense"] = data_structures.List()
|
||||
self.layer_dict["dense"].extend(
|
||||
[core.Dense(5),
|
||||
core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
|
||||
self.layer_dict["norm"].append(
|
||||
normalization.BatchNormalization())
|
||||
self.layer_dict["norm"].append(
|
||||
normalization.BatchNormalization())
|
||||
|
||||
def call(self, x):
|
||||
aggregation = 0.
|
||||
for norm, dense in zip(self.layer_dict["norm"], self.layer_dict["dense"]):
|
||||
x = norm(dense(x))
|
||||
aggregation += math_ops.reduce_sum(x)
|
||||
return self.layer_dict["output"](x) / aggregation
|
||||
|
||||
|
||||
class MappingTests(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTracking(self):
|
||||
model = HasMapping()
|
||||
output = model(array_ops.ones([32, 2]))
|
||||
self.assertAllEqual([32, 7], output.shape.as_list())
|
||||
self.assertEqual(5, len(model.layers))
|
||||
six.assertCountEqual(self, model.layers, model.layer_dict.layers)
|
||||
self.assertEqual(1, len(model._checkpoint_dependencies))
|
||||
self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref)
|
||||
self.evaluate([v.initializer for v in model.variables])
|
||||
test_var = model.layer_dict["output"].kernel
|
||||
self.evaluate(test_var.assign(array_ops.ones([6, 7])))
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
self.evaluate(test_var.assign(array_ops.zeros([6, 7])))
|
||||
model.load_weights(save_path)
|
||||
self.assertAllEqual(numpy.ones([6, 7]),
|
||||
self.evaluate(test_var))
|
||||
|
||||
def testLayerCollectionWithExternalMutation(self):
|
||||
d = {}
|
||||
root = tracking.AutoTrackable()
|
||||
root.wrapper = d
|
||||
self.assertEqual([], root.wrapper.layers)
|
||||
self.assertEqual([], root.wrapper.trainable_weights)
|
||||
layer1 = core.Dense(1)
|
||||
layer2 = core.Dense(1)
|
||||
d["a"] = layer1
|
||||
d["b"] = layer2
|
||||
self.assertEqual([layer1, layer2], root.wrapper.layers)
|
||||
# The layers have still not created variables
|
||||
self.assertEqual([], root.wrapper.trainable_weights)
|
||||
|
||||
def testDictWrapperBadKeys(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = {}
|
||||
a.d[1] = data_structures.List()
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "non-string key"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testDictWrapperNoDependency(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = data_structures.NoDependency({})
|
||||
a.d[1] = [3]
|
||||
self.assertEqual([a], util.list_objects(a))
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testNonStringKeyNotTrackableValue(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = {}
|
||||
a.d["a"] = [3]
|
||||
a.d[1] = data_structures.NoDependency([3])
|
||||
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testNonAppendNotTrackable(self):
|
||||
# Non-append mutations (deleting or overwriting values) are OK when the
|
||||
# values aren't tracked.
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = {}
|
||||
a.d["a"] = [3]
|
||||
a.d[1] = 3
|
||||
a.d[1] = 2
|
||||
self.assertEqual(2, a.d[1])
|
||||
del a.d[1]
|
||||
a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
|
||||
second = tracking.AutoTrackable()
|
||||
a.d[2] = data_structures.NoDependency(second)
|
||||
self.assertIs(second, a.d[2])
|
||||
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testPopNoSave(self):
|
||||
model = training.Model()
|
||||
model.d = {}
|
||||
model.d["a"] = []
|
||||
model.d.pop("a")
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "Unable to save"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testExternalModificationNoSave(self):
|
||||
model = training.Model()
|
||||
external_reference = {}
|
||||
model.d = external_reference
|
||||
external_reference["a"] = []
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testOverwriteCanStillSave(self):
|
||||
model = training.Model()
|
||||
model.d = {}
|
||||
model.d["a"] = {}
|
||||
model.d["a"] = {}
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testIter(self):
|
||||
model = training.Model()
|
||||
model.d = {1: 3}
|
||||
model.d[1] = 3
|
||||
self.assertEqual([1], list(model.d))
|
||||
new_dict = {}
|
||||
# This update() is super tricky. If the dict wrapper subclasses dict,
|
||||
# CPython will access its storage directly instead of calling any
|
||||
# methods/properties on the object. So the options are either not to
|
||||
# subclass dict (in which case update will call normal iter methods, but the
|
||||
# object won't pass isinstance checks) or to subclass dict and keep that
|
||||
# storage updated (no shadowing all its methods like ListWrapper).
|
||||
new_dict.update(model.d)
|
||||
self.assertEqual({1: 3}, new_dict)
|
||||
|
||||
|
||||
class HasTuple(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasTuple, self).__init__()
|
||||
self.layer_list = (
|
||||
core.Dense(3), core.Dense(4),
|
||||
core.Dense(5, kernel_regularizer=math_ops.reduce_sum))
|
||||
self.layers_with_updates = (normalization.BatchNormalization(),)
|
||||
|
||||
def call(self, x):
|
||||
aggregation = 0.
|
||||
for l in self.layer_list:
|
||||
x = l(x)
|
||||
aggregation += math_ops.reduce_sum(x)
|
||||
bn, = self.layers_with_updates
|
||||
return bn(x) / aggregation
|
||||
|
||||
|
||||
class TupleTests(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTracking(self):
|
||||
model = HasTuple()
|
||||
output = model(array_ops.ones([32, 2]))
|
||||
self.assertAllEqual([32, 5], output.shape.as_list())
|
||||
self.assertLen(model.layers, 4)
|
||||
self.assertLen(model.layer_list.layers, 3)
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
model.layers,
|
||||
tuple(model.layer_list.layers) + model.layers_with_updates)
|
||||
self.assertEqual(3, model.layer_list.layers[0].units)
|
||||
self.assertEqual(4, model.layer_list.layers[1].units)
|
||||
self.assertEqual(5, model.layer_list.layers[2].units)
|
||||
self.assertLen(model._checkpoint_dependencies, 2)
|
||||
self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
|
||||
self.assertIs(model.layers_with_updates,
|
||||
model._checkpoint_dependencies[1].ref)
|
||||
self.assertLen(
|
||||
model._checkpoint_dependencies[0].ref._checkpoint_dependencies, 3)
|
||||
self.evaluate([v.initializer for v in model.variables])
|
||||
self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
|
||||
model.load_weights(save_path)
|
||||
self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
|
||||
self.evaluate(model.variables[0]))
|
||||
v = variables.Variable(1.)
|
||||
model.var_list = (v,)
|
||||
self.assertIn(id(v), [id(obj) for obj in model.variables])
|
||||
self.assertIn(id(v), [id(obj) for obj in model.trainable_variables])
|
||||
self.assertNotIn(id(v), [id(obj) for obj in model.non_trainable_variables])
|
||||
self.assertIn(id(model.layer_list[0].trainable_weights[0]),
|
||||
[id(obj) for obj in model.trainable_weights])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Module", module.Module),
|
||||
("Model", training.Model),
|
||||
)
|
||||
def testSubModelTracking(self, module_subclass):
|
||||
model = module_subclass()
|
||||
model.v = variables.Variable(1.)
|
||||
self.assertIn(model.v, model.trainable_variables)
|
||||
model2 = module_subclass()
|
||||
model2.m = (model,)
|
||||
self.assertIn(model.v, model2.trainable_variables)
|
||||
|
||||
def testSubSequentialTracking(self):
|
||||
|
||||
class _Subclassed(training.Model):
|
||||
|
||||
def __init__(self, wrapped):
|
||||
super(_Subclassed, self).__init__()
|
||||
self._wrapped = wrapped
|
||||
|
||||
def call(self, x):
|
||||
return self._wrapped(x)
|
||||
|
||||
model = sequential.Sequential()
|
||||
layer = core.Dense(1)
|
||||
model.add(layer)
|
||||
model2 = _Subclassed(model)
|
||||
model2(array_ops.ones([1, 2]))
|
||||
model2.m = (model,)
|
||||
self.assertIn(layer.kernel, model2.trainable_weights)
|
||||
|
||||
def testUpdatesForwarded(self):
|
||||
with ops.Graph().as_default():
|
||||
model = HasTuple()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertNotEmpty(model.layers_with_updates[0].updates)
|
||||
self.assertEqual(set(model.layers_with_updates[0].updates),
|
||||
set(model.updates))
|
||||
|
||||
model = HasTuple()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertEmpty(model.updates)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testLossesForwarded(self):
|
||||
model = HasTuple()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertLen(model.losses, 1)
|
||||
|
||||
def testModelContainersCompareEqual(self):
|
||||
class HasEqualContainers(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasEqualContainers, self).__init__()
|
||||
self.l1 = ()
|
||||
self.l2 = ()
|
||||
|
||||
model = HasEqualContainers()
|
||||
first_layer = HasEqualContainers()
|
||||
model.l1 = (first_layer,)
|
||||
second_layer = HasEqualContainers()
|
||||
model.l2 = (second_layer,)
|
||||
self.assertEqual((first_layer,), model.l1)
|
||||
d = {model.l1: 1, model.l2: 2}
|
||||
self.assertEqual(1, d[model.l1])
|
||||
self.assertEqual(1, d[(first_layer,)])
|
||||
self.assertEqual(2, d[model.l2])
|
||||
self.assertEqual(2, d[(second_layer,)])
|
||||
self.assertEqual([first_layer, second_layer], model.layers)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTensorConversion(self):
|
||||
|
||||
class TupleToTensor(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(TupleToTensor, self).__init__()
|
||||
self.l = (1., 2., 3.)
|
||||
|
||||
self.assertAllEqual(
|
||||
(1., 2., 3.),
|
||||
self.evaluate(constant_op.constant(TupleToTensor().l)))
|
||||
|
||||
self.assertAllEqual(
|
||||
(1., 2., 3.),
|
||||
self.evaluate(array_ops.pack(TupleToTensor().l)))
|
||||
|
||||
|
||||
class InterfaceTests(test.TestCase):
|
||||
|
||||
def testNoDependency(self):
|
||||
root = tracking.AutoTrackable()
|
||||
hasdep = tracking.AutoTrackable()
|
||||
root.hasdep = hasdep
|
||||
nodep = tracking.AutoTrackable()
|
||||
root.nodep = data_structures.NoDependency(nodep)
|
||||
self.assertEqual(1, len(root._checkpoint_dependencies))
|
||||
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
|
||||
self.assertIs(root.hasdep, hasdep)
|
||||
self.assertIs(root.nodep, nodep)
|
||||
|
||||
class NoDependencyModel(training.Model):
|
||||
|
||||
@base.no_automatic_dependency_tracking
|
||||
def __init__(self):
|
||||
super(NoDependencyModel, self).__init__()
|
||||
self.a = []
|
||||
self.b = tracking.AutoTrackable()
|
||||
|
||||
nodeps = NoDependencyModel()
|
||||
self.assertEqual([nodeps], util.list_objects(nodeps))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDictionariesBasic(self):
|
||||
a = training.Model()
|
||||
b = training.Model()
|
||||
a.attribute = {"b": b}
|
||||
c = training.Model()
|
||||
a.attribute["c"] = []
|
||||
a.attribute["c"].append(c)
|
||||
a_deps = util.list_objects(a)
|
||||
self.assertIn(b, a_deps)
|
||||
self.assertIn(c, a_deps)
|
||||
self.assertIs(b, a.attribute["b"])
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
["b", "c"],
|
||||
[dep.name for dep in a.attribute._checkpoint_dependencies])
|
||||
self.assertEqual([b, c], a.layers)
|
||||
self.assertEqual([b, c], a.attribute.layers)
|
||||
self.assertEqual([c], a.attribute["c"].layers)
|
||||
checkpoint = util.Checkpoint(a=a)
|
||||
save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
with self.cached_session():
|
||||
checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testNoDepList(self):
|
||||
a = training.Model()
|
||||
a.l1 = data_structures.NoDependency([])
|
||||
a.l1.insert(1, 0)
|
||||
self.assertIsInstance(a.l1, list)
|
||||
checkpoint = util.Checkpoint(a=a)
|
||||
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
a.l2 = []
|
||||
a.l2.insert(1, module.Module())
|
||||
with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
|
||||
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
926
tensorflow/python/keras/tests/tracking_util_test.py
Normal file
926
tensorflow/python/keras/tests/tracking_util_test.py
Normal file
@ -0,0 +1,926 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
import weakref
|
||||
|
||||
from absl.testing import parameterized
|
||||
import six
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import template
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
|
||||
# pylint: disable=not-callable
|
||||
class MyModel(training.Model):
|
||||
"""A concrete Model for testing."""
|
||||
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self._named_dense = core.Dense(1, use_bias=True)
|
||||
self._second = core.Dense(1, use_bias=False)
|
||||
# We can still track Trackables which aren't Layers.
|
||||
self._non_layer = NonLayerTrackable()
|
||||
|
||||
def call(self, values):
|
||||
ret = self._second(self._named_dense(values))
|
||||
return ret
|
||||
|
||||
|
||||
class NonLayerTrackable(tracking.AutoTrackable):
|
||||
|
||||
def __init__(self):
|
||||
super(NonLayerTrackable, self).__init__()
|
||||
self.a_variable = trackable_utils.add_variable(
|
||||
self, name="a_variable", shape=[])
|
||||
|
||||
|
||||
class InterfaceTests(test.TestCase):
|
||||
|
||||
def testLayerDeduplication(self):
|
||||
model = training.Model()
|
||||
layer_one = core.Dense(1)
|
||||
layer_two = core.Dense(1)
|
||||
model.other_path = [layer_one, layer_two]
|
||||
model.l2 = layer_two
|
||||
model.l1 = layer_one
|
||||
self.assertEqual([layer_one, layer_two], model.layers)
|
||||
|
||||
def testSaveWithOnlyKerasSession(self):
|
||||
|
||||
with ops.Graph().as_default():
|
||||
inp = input_layer.Input([1])
|
||||
dense = core.Dense(1)(inp)
|
||||
model = training.Model(inp, dense)
|
||||
model.compile(optimizer="sgd", loss="mse")
|
||||
model.fit([1.], [2.])
|
||||
checkpoint = trackable_utils.Checkpoint(model=model)
|
||||
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
|
||||
def testObjectMetadata(self):
|
||||
with context.eager_mode():
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
dense = core.Dense(1)
|
||||
checkpoint = trackable_utils.Checkpoint(dense=dense)
|
||||
dense(constant_op.constant([[1.]]))
|
||||
save_path = checkpoint.save(checkpoint_prefix)
|
||||
|
||||
objects = trackable_utils.object_metadata(save_path)
|
||||
all_variable_names = []
|
||||
for obj in objects.nodes:
|
||||
for attribute in obj.attributes:
|
||||
all_variable_names.append(attribute.full_name)
|
||||
self.assertIn("dense/kernel", all_variable_names)
|
||||
|
||||
|
||||
class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testNamingWithOptimizer(self):
|
||||
input_value = constant_op.constant([[3.]])
|
||||
model = MyModel()
|
||||
# A nuisance Model using the same optimizer. Its slot variables should not
|
||||
# go in the checkpoint, since it is never depended on.
|
||||
other_model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
step = training_util.get_or_create_global_step()
|
||||
root_trackable = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model, step=step)
|
||||
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = control_flow_ops.group(
|
||||
optimizer.apply_gradients(zip(gradients, variables)),
|
||||
step.assign_add(1))
|
||||
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = other_model(input_value)
|
||||
variables = other_model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
optimizer.apply_gradients(zip(gradients, variables))
|
||||
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
|
||||
root_trackable).serialize_object_graph()
|
||||
expected_slot_keys = (
|
||||
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
|
||||
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
|
||||
"model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
|
||||
"model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
|
||||
"model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
|
||||
"model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
|
||||
)
|
||||
expected_checkpoint_names = (
|
||||
# Created in the root node, so no prefix.
|
||||
"step",
|
||||
"model/_second/kernel",
|
||||
"model/_named_dense/kernel",
|
||||
"model/_named_dense/bias",
|
||||
# non-Layer dependency of the model
|
||||
"model/_non_layer/a_variable",
|
||||
"optimizer/learning_rate",
|
||||
"optimizer/beta_1",
|
||||
"optimizer/beta_2",
|
||||
"optimizer/iter",
|
||||
"optimizer/decay",
|
||||
) + expected_slot_keys
|
||||
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
expected_checkpoint_names = [
|
||||
name + suffix for name in expected_checkpoint_names]
|
||||
named_variables = {v.name: v for v in named_variables}
|
||||
six.assertCountEqual(self, expected_checkpoint_names,
|
||||
named_variables.keys())
|
||||
# Check that we've mapped to the right variable objects (not exhaustive)
|
||||
self.assertEqual(
|
||||
"global_step",
|
||||
named_variables["step" + suffix].full_name)
|
||||
self.assertEqual(
|
||||
"my_model/dense_1/kernel",
|
||||
named_variables["model/_second/kernel" + suffix].full_name)
|
||||
self.assertEqual(
|
||||
"my_model/dense/kernel",
|
||||
named_variables["model/_named_dense/kernel" + suffix].full_name)
|
||||
self.assertEqual("Adam/beta_1",
|
||||
named_variables["optimizer/beta_1" + suffix].full_name)
|
||||
self.assertEqual("Adam/beta_2",
|
||||
named_variables["optimizer/beta_2" + suffix].full_name)
|
||||
# Spot check the generated protocol buffers.
|
||||
self.assertEqual("optimizer",
|
||||
serialized_graph.nodes[0].children[1].local_name)
|
||||
optimizer_node = serialized_graph.nodes[
|
||||
serialized_graph.nodes[0].children[1].node_id]
|
||||
children = [node.local_name for node in optimizer_node.children]
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
# hyper variable dependencies
|
||||
["beta_1", "beta_2", "iter", "decay", "learning_rate"],
|
||||
children)
|
||||
serialized_slot_keys = []
|
||||
for slot in optimizer_node.slot_variables:
|
||||
for attribute in (
|
||||
serialized_graph.nodes[slot.slot_variable_node_id].attributes):
|
||||
serialized_slot_keys.append(attribute.checkpoint_key)
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
[key + suffix for key in expected_slot_keys],
|
||||
serialized_slot_keys)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testSaveRestore(self):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root_trackable = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
self.assertFalse(root_trackable.save_counter.trainable)
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
|
||||
m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
|
||||
self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
|
||||
save_path = root_trackable.save(file_prefix=prefix)
|
||||
self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
|
||||
self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
|
||||
optimizer_variables = self.evaluate(
|
||||
sorted(optimizer.variables(), key=lambda v: v.name))
|
||||
self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
|
||||
# Immediate restoration
|
||||
status = root_trackable.restore(save_path=save_path).assert_consumed()
|
||||
status.run_restore_ops()
|
||||
self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
|
||||
self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
|
||||
self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
|
||||
if not context.executing_eagerly():
|
||||
return # Restore-on-create is only supported when executing eagerly
|
||||
on_create_model = MyModel()
|
||||
on_create_optimizer = adam.Adam(0.001)
|
||||
on_create_root = trackable_utils.Checkpoint(
|
||||
optimizer=on_create_optimizer, model=on_create_model)
|
||||
# Deferred restoration
|
||||
status = on_create_root.restore(save_path=save_path)
|
||||
status.assert_nontrivial_match()
|
||||
status.assert_existing_objects_matched()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
on_create_model(constant_op.constant([[3.]])) # create variables
|
||||
self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
|
||||
self.assertAllEqual([42.],
|
||||
self.evaluate(
|
||||
on_create_model._named_dense.variables[1]))
|
||||
on_create_m_bias_slot = on_create_optimizer.get_slot(
|
||||
on_create_model._named_dense.variables[1], "m")
|
||||
status.assert_existing_objects_matched()
|
||||
if not context.executing_eagerly():
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
# Optimizer slot variables are created when the original variable is
|
||||
# restored.
|
||||
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
|
||||
dummy_var = resource_variable_ops.ResourceVariable([1.])
|
||||
on_create_optimizer.minimize(loss=dummy_var.read_value,
|
||||
var_list=[dummy_var])
|
||||
status.assert_existing_objects_matched()
|
||||
status.assert_consumed()
|
||||
self.assertAllEqual(
|
||||
optimizer_variables,
|
||||
# Creation order is different, so .variables() needs to be re-sorted.
|
||||
self.evaluate(sorted(optimizer.variables(), key=lambda v: v.name)))
|
||||
|
||||
# TODO(allenl): Debug garbage created by this test in python3.
|
||||
def testDeferredRestorationUsageEager(self):
|
||||
"""An idiomatic eager execution example."""
|
||||
num_training_steps = 10
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
for training_continuation in range(3):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
root.restore(checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory))
|
||||
for _ in range(num_training_steps):
|
||||
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
|
||||
input_value = constant_op.constant([[3.]])
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
optimizer.apply_gradients(zip(gradients, variables))
|
||||
root.save(file_prefix=checkpoint_prefix)
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
root.optimizer.iterations.numpy())
|
||||
|
||||
def testUsageGraph(self):
|
||||
"""Expected usage when graph building."""
|
||||
with context.graph_mode():
|
||||
num_training_steps = 10
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
for training_continuation in range(3):
|
||||
with ops.Graph().as_default():
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.CheckpointV1(
|
||||
optimizer=optimizer, model=model)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
|
||||
checkpoint_path = checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)
|
||||
with self.session(graph=ops.get_default_graph()) as session:
|
||||
status = root.restore(save_path=checkpoint_path)
|
||||
status.initialize_or_restore(session=session)
|
||||
if checkpoint_path is None:
|
||||
self.assertEqual(0, training_continuation)
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
else:
|
||||
status.assert_consumed()
|
||||
status.assert_existing_objects_matched()
|
||||
for _ in range(num_training_steps):
|
||||
session.run(train_op)
|
||||
root.save(file_prefix=checkpoint_prefix, session=session)
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
session.run(root.optimizer.iterations))
|
||||
self.assertEqual(training_continuation + 1,
|
||||
session.run(root.save_counter))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAgnosticUsage(self):
|
||||
"""Graph/eager agnostic usage."""
|
||||
# Does create garbage when executing eagerly due to ops.Graph() creation.
|
||||
num_training_steps = 10
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
def _train_fn(model, input_value):
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
for training_continuation in range(3):
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
root, checkpoint_directory, max_to_keep=1)
|
||||
status = root.restore(save_path=manager.latest_checkpoint)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
train_fn = functools.partial(_train_fn, model, input_value)
|
||||
if not context.executing_eagerly():
|
||||
train_fn = functools.partial(self.evaluate, train_fn())
|
||||
status.initialize_or_restore()
|
||||
for _ in range(num_training_steps):
|
||||
train_fn()
|
||||
manager.save()
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
self.evaluate(root.optimizer.iterations))
|
||||
self.assertEqual(training_continuation + 1,
|
||||
self.evaluate(root.save_counter))
|
||||
|
||||
def testPartialRestoreWarningObject(self):
|
||||
with context.eager_mode():
|
||||
optimizer = adam.Adam(0.0)
|
||||
original_root = trackable_utils.Checkpoint(v1=variables_lib.Variable(2.),
|
||||
v2=variables_lib.Variable(3.),
|
||||
optimizer=optimizer)
|
||||
# Create a slot variable to save
|
||||
optimizer.minimize(original_root.v1.read_value, [original_root.v1])
|
||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
save_path = original_root.save(prefix)
|
||||
partial_root = trackable_utils.Checkpoint(v1=variables_lib.Variable(0.))
|
||||
weak_partial_root = weakref.ref(partial_root)
|
||||
weak_v1 = weakref.ref(partial_root.v1)
|
||||
partial_root.restore(save_path)
|
||||
self.assertEqual(2., partial_root.v1.numpy())
|
||||
with test.mock.patch.object(logging, "warning") as mock_log:
|
||||
del partial_root
|
||||
self.assertIsNone(weak_partial_root())
|
||||
self.assertIsNone(weak_v1())
|
||||
messages = str(mock_log.call_args_list)
|
||||
self.assertIn("(root).v2'", messages)
|
||||
self.assertIn("(root).optimizer's state 'm' for (root).v1", messages)
|
||||
self.assertNotIn("(root).v1'", messages)
|
||||
self.assertIn("expect_partial()", messages)
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testWithDefun(self):
|
||||
num_training_steps = 2
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
for training_continuation in range(3):
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
# Don't actually train so we can test variable values
|
||||
optimizer = adam.Adam(0.)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
checkpoint_path = checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)
|
||||
status = root.restore(save_path=checkpoint_path)
|
||||
def train_fn():
|
||||
@def_function.function
|
||||
def _call_model(x):
|
||||
return model(x)
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = _call_model(constant_op.constant([[3.]]))
|
||||
gradients = tape.gradient(loss, model.variables)
|
||||
return optimizer.apply_gradients(zip(gradients, model.variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn = functools.partial(
|
||||
self.evaluate, train_fn())
|
||||
status.initialize_or_restore()
|
||||
for _ in range(num_training_steps):
|
||||
train_fn()
|
||||
if training_continuation > 0:
|
||||
status.assert_consumed()
|
||||
self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
|
||||
else:
|
||||
self.evaluate(model.variables[0].assign([[42.]]))
|
||||
root.save(file_prefix=checkpoint_prefix)
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
self.evaluate(optimizer.iterations))
|
||||
self.assertEqual(training_continuation + 1,
|
||||
self.evaluate(root.save_counter))
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
def testAnonymousVarsInInit(self):
|
||||
|
||||
class Model(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.w = resource_variable_ops.ResourceVariable(0.0)
|
||||
self.b = resource_variable_ops.ResourceVariable(0.0)
|
||||
self.vars = [self.w, self.b]
|
||||
|
||||
def call(self, x):
|
||||
return x * self.w + self.b
|
||||
|
||||
with context.eager_mode():
|
||||
model = Model()
|
||||
optimizer = adam.Adam(learning_rate=0.05)
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
checkpoint = trackable_utils.Checkpoint(
|
||||
model=model, optimizer=optimizer)
|
||||
for _ in range(2):
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = (constant_op.constant(1.)
|
||||
- model(constant_op.constant(1.))) ** 2
|
||||
grad = tape.gradient(loss, model.vars)
|
||||
optimizer.apply_gradients(
|
||||
[(g, v) for g, v in zip(grad, model.vars)])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDeferredSlotRestoration(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
|
||||
root = trackable_utils.Checkpoint()
|
||||
root.var = trackable_utils.add_variable(
|
||||
root, name="var", initializer=0.)
|
||||
optimizer = adam.Adam(0.1)
|
||||
variables = [root.var]
|
||||
gradients = [1.]
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
# Note that `optimizer` has not been added as a dependency of
|
||||
# `root`. Create a one-off grouping so that slot variables for `root.var`
|
||||
# get initialized too.
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
|
||||
self.evaluate(train_op)
|
||||
self.evaluate(state_ops.assign(root.var, 12.))
|
||||
no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
|
||||
root.optimizer = optimizer
|
||||
self.evaluate(state_ops.assign(root.var, 13.))
|
||||
self.evaluate(state_ops.assign(
|
||||
optimizer.get_slot(slot_name="m", var=root.var),
|
||||
14.))
|
||||
slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
|
||||
new_root = trackable_utils.Checkpoint()
|
||||
# Load the slot-containing checkpoint (deferred), then immediately overwrite
|
||||
# the non-slot variable (also deferred).
|
||||
slot_status = new_root.restore(slots_path)
|
||||
no_slot_status = new_root.restore(no_slots_path)
|
||||
with self.assertRaises(AssertionError):
|
||||
no_slot_status.assert_consumed()
|
||||
new_root.var = trackable_utils.add_variable(
|
||||
new_root, name="var", shape=[])
|
||||
no_slot_status.assert_consumed()
|
||||
no_slot_status.run_restore_ops()
|
||||
self.assertEqual(12., self.evaluate(new_root.var))
|
||||
new_root.optimizer = adam.Adam(0.1)
|
||||
slot_status.assert_existing_objects_matched()
|
||||
if not context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(AssertionError, "Unresolved object"):
|
||||
slot_status.assert_consumed()
|
||||
self.assertEqual(12., self.evaluate(new_root.var))
|
||||
if context.executing_eagerly():
|
||||
# Slot variables are only created with restoring initializers when
|
||||
# executing eagerly.
|
||||
self.assertEqual(14., self.evaluate(
|
||||
new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
|
||||
else:
|
||||
# Slot variables are not created eagerly when graph building.
|
||||
with self.assertRaises(KeyError):
|
||||
new_root.optimizer.get_slot(slot_name="m", var=new_root.var)
|
||||
variables = [new_root.var]
|
||||
gradients = [1.]
|
||||
train_op = new_root.optimizer.apply_gradients(zip(gradients, variables))
|
||||
# The slot variable now exists; restore() didn't create it, but we should
|
||||
# now have a restore op for it.
|
||||
slot_status.run_restore_ops()
|
||||
if not context.executing_eagerly():
|
||||
# The train op hasn't run when graph building, so the slot variable has
|
||||
# its restored value. It has run in eager, so the value will be different.
|
||||
self.assertEqual(14., self.evaluate(
|
||||
new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
|
||||
self.evaluate(train_op)
|
||||
slot_status.assert_consumed()
|
||||
|
||||
def testManySavesGraph(self):
|
||||
"""Saves after the first should not modify the graph."""
|
||||
with context.graph_mode():
|
||||
graph = ops.Graph()
|
||||
with graph.as_default(), self.session(graph):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
obj = trackable_utils.Checkpoint()
|
||||
obj.var = variables_lib.Variable(0., name="v")
|
||||
obj.opt = adam.Adam(0.1)
|
||||
variables = [obj.var]
|
||||
gradients = [1.]
|
||||
obj.opt.apply_gradients(zip(gradients, variables))
|
||||
self.evaluate(trackable_utils.gather_initializers(obj))
|
||||
obj.save(checkpoint_prefix)
|
||||
graph.finalize()
|
||||
obj.save(checkpoint_prefix)
|
||||
|
||||
def testManyRestoresGraph(self):
|
||||
"""Restores after the first should not modify the graph."""
|
||||
with context.graph_mode():
|
||||
graph = ops.Graph()
|
||||
with graph.as_default(), self.session(graph):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
obj = trackable_utils.Checkpoint()
|
||||
obj.var = variables_lib.Variable(0., name="v")
|
||||
obj.opt = adam.Adam(0.1)
|
||||
variables = [obj.var]
|
||||
gradients = [1.]
|
||||
obj.opt.apply_gradients(zip(gradients, variables))
|
||||
self.evaluate(trackable_utils.gather_initializers(obj))
|
||||
save_path = obj.save(checkpoint_prefix)
|
||||
obj.restore(save_path)
|
||||
graph.finalize()
|
||||
obj.restore(save_path)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential(self):
|
||||
model = sequential.Sequential()
|
||||
checkpoint = trackable_utils.Checkpoint(model=model)
|
||||
model.add(core.Dense(4))
|
||||
second_dense = core.Dense(5)
|
||||
model.add(second_dense)
|
||||
model(constant_op.constant([[1.]]))
|
||||
checkpoint.restore(None).initialize_or_restore()
|
||||
self.evaluate(second_dense.bias.assign(
|
||||
constant_op.constant([1., 2., 3., 4., 5.])))
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
save_path = checkpoint.save(checkpoint_prefix)
|
||||
self.evaluate(second_dense.bias.assign(
|
||||
constant_op.constant([5., 6., 7., 8., 9.])))
|
||||
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
|
||||
self.assertAllEqual([1., 2., 3., 4., 5.], self.evaluate(second_dense.bias))
|
||||
|
||||
deferred_sequential = sequential.Sequential()
|
||||
deferred_sequential_checkpoint = trackable_utils.Checkpoint(
|
||||
model=deferred_sequential)
|
||||
status = deferred_sequential_checkpoint.restore(save_path)
|
||||
deferred_sequential.add(core.Dense(4))
|
||||
deferred_second_dense = core.Dense(5)
|
||||
deferred_sequential.add(deferred_second_dense)
|
||||
deferred_sequential(constant_op.constant([[1.]]))
|
||||
status.run_restore_ops()
|
||||
self.assertAllEqual([1., 2., 3., 4., 5.],
|
||||
self.evaluate(deferred_second_dense.bias))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_initialize_if_not_restoring(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
model=model) # Do not save the optimizer with the checkpoint.
|
||||
optimizer_checkpoint = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer)
|
||||
|
||||
checkpoint_path = checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)
|
||||
status = root.restore(save_path=checkpoint_path)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
def train_fn():
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn = functools.partial(self.evaluate, train_fn())
|
||||
status.initialize_or_restore()
|
||||
# TODO(tanzheny): Add hyper variables to .variables(), and set them with
|
||||
# set_weights etc.
|
||||
variables_not_in_the_variables_property = [
|
||||
obj for obj in optimizer._hyper.values()
|
||||
if isinstance(obj, variables_lib.Variable)]
|
||||
self.evaluate([v.initializer for v
|
||||
in optimizer.variables()
|
||||
+ variables_not_in_the_variables_property])
|
||||
train_fn()
|
||||
model_save_path = root.save(file_prefix=checkpoint_prefix)
|
||||
self.evaluate(optimizer.beta_1.assign(42.))
|
||||
optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)
|
||||
del train_fn
|
||||
|
||||
# Restore into a graph with the optimizer
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
status = root.restore(save_path=model_save_path)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
def train_fn1():
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn1 = functools.partial(self.evaluate, train_fn1())
|
||||
status.initialize_or_restore()
|
||||
train_fn1()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
del train_fn1
|
||||
|
||||
# Make sure initialization doesn't clobber later restores
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001, beta_1=1.0)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
opt_root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer)
|
||||
status = root.restore(save_path=model_save_path)
|
||||
init_only_optimizer_status = opt_root.restore(save_path=None)
|
||||
optimizer_status = opt_root.restore(save_path=optimizer_save_path)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
def train_fn2():
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn2 = functools.partial(self.evaluate, train_fn2())
|
||||
optimizer_status.run_restore_ops()
|
||||
status.initialize_or_restore()
|
||||
init_only_optimizer_status.initialize_or_restore()
|
||||
train_fn2()
|
||||
self.assertEqual(42., self.evaluate(optimizer.beta_1))
|
||||
|
||||
|
||||
class _ManualScope(tracking.AutoTrackable):
|
||||
|
||||
def __call__(self):
|
||||
with variable_scope.variable_scope("ManualScope") as vs:
|
||||
self.variable_scope = vs
|
||||
with trackable_utils.capture_dependencies(template=self):
|
||||
return self._build()
|
||||
|
||||
def _build(self):
|
||||
return variable_scope.get_variable(name="in_manual_scope", shape=[])
|
||||
|
||||
|
||||
class TemplateTests(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_trackable_save_restore(self):
|
||||
|
||||
def _templated():
|
||||
v = variable_scope.get_variable(
|
||||
"v", shape=[1], initializer=init_ops.zeros_initializer(),
|
||||
use_resource=True)
|
||||
v2 = variable_scope.get_variable(
|
||||
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
|
||||
use_resource=True)
|
||||
manual = _ManualScope()
|
||||
return v, v + 1., v2, manual, manual()
|
||||
|
||||
save_template = template.make_template("s1", _templated)
|
||||
v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
[id(v1_save), id(v2_save), id(manual_scope),
|
||||
id(manual_scope_v), id(save_template)],
|
||||
map(id, trackable_utils.list_objects(save_template)))
|
||||
manual_dep, = manual_scope._checkpoint_dependencies
|
||||
self.assertEqual("in_manual_scope", manual_dep.name)
|
||||
self.assertIs(manual_scope_v, manual_dep.ref)
|
||||
optimizer = adam.Adam(0.0)
|
||||
save_root = trackable_utils.Checkpoint(
|
||||
my_template=save_template, optimizer=optimizer)
|
||||
optimizer.minimize(v1_save.read_value,
|
||||
var_list=[v1_save])
|
||||
self.evaluate([v.initializer for v in save_template.variables])
|
||||
optimizer_variables = optimizer.variables() + list(
|
||||
optimizer._hyper.values())
|
||||
self.evaluate([v.initializer for v in optimizer_variables])
|
||||
self.evaluate(v1_save.assign([12.]))
|
||||
self.evaluate(v2_save.assign([14.]))
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
save_path = save_root.save(checkpoint_prefix)
|
||||
|
||||
load_template = template.make_template("s2", _templated)
|
||||
load_optimizer = adam.Adam(0.0)
|
||||
load_root = trackable_utils.Checkpoint(
|
||||
my_template=load_template, optimizer=load_optimizer)
|
||||
status = load_root.restore(save_path)
|
||||
var, var_plus_one, var2, _, _ = load_template()
|
||||
load_optimizer.minimize(var.read_value, var_list=[var])
|
||||
self.assertLen(load_template._checkpoint_dependencies, 3)
|
||||
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
|
||||
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
|
||||
self.assertEqual("ManualScope",
|
||||
load_template._checkpoint_dependencies[2].name)
|
||||
status.assert_consumed().run_restore_ops()
|
||||
self.assertAllEqual([12.], self.evaluate(var))
|
||||
self.assertAllEqual([13.], self.evaluate(var_plus_one))
|
||||
self.assertAllEqual([14.], self.evaluate(var2))
|
||||
|
||||
|
||||
class CheckpointCompatibilityTests(test.TestCase):
|
||||
|
||||
def _initialized_model(self):
|
||||
input_value = constant_op.constant([[3.]])
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root_trackable = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
# A regular variable, a slot variable, and a non-slot Optimizer variable
|
||||
# with known values to check when loading.
|
||||
self.evaluate(model._named_dense.bias.assign([1.]))
|
||||
self.evaluate(optimizer.get_slot(
|
||||
var=model._named_dense.bias, slot_name="m").assign([2.]))
|
||||
self.evaluate(optimizer.beta_1.assign(3.))
|
||||
return root_trackable
|
||||
|
||||
def _set_sentinels(self, root_trackable):
|
||||
self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
|
||||
self.evaluate(
|
||||
root_trackable.optimizer.get_slot(
|
||||
var=root_trackable.model._named_dense.bias, slot_name="m")
|
||||
.assign([102.]))
|
||||
self.evaluate(root_trackable.optimizer.beta_1.assign(103.))
|
||||
|
||||
def _check_sentinels(self, root_trackable):
|
||||
self.assertAllEqual(
|
||||
[1.], self.evaluate(root_trackable.model._named_dense.bias))
|
||||
self.assertAllEqual([2.], self.evaluate(
|
||||
root_trackable.optimizer.get_slot(
|
||||
var=root_trackable.model._named_dense.bias, slot_name="m")))
|
||||
self.assertAllEqual(3.,
|
||||
self.evaluate(root_trackable.optimizer.beta_1))
|
||||
|
||||
def _write_name_based_checkpoint(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with context.graph_mode():
|
||||
save_graph = ops.Graph()
|
||||
with save_graph.as_default(), self.session(
|
||||
graph=save_graph) as session:
|
||||
root = self._initialized_model()
|
||||
name_saver = saver_lib.Saver()
|
||||
return name_saver.save(
|
||||
sess=session,
|
||||
save_path=checkpoint_prefix,
|
||||
global_step=root.optimizer.iterations)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testLoadFromNameBasedSaver(self):
|
||||
"""Save a name-based checkpoint, load it using the object-based API."""
|
||||
with test_util.device(use_gpu=True):
|
||||
save_path = self._write_name_based_checkpoint()
|
||||
root = self._initialized_model()
|
||||
self._set_sentinels(root)
|
||||
with self.assertRaises(AssertionError):
|
||||
self._check_sentinels(root)
|
||||
object_saver = trackable_utils.TrackableSaver(
|
||||
graph_view.ObjectGraphView(root))
|
||||
self._set_sentinels(root)
|
||||
status = object_saver.restore(save_path)
|
||||
if context.executing_eagerly():
|
||||
self._check_sentinels(root)
|
||||
if context.executing_eagerly():
|
||||
status.assert_consumed()
|
||||
status.assert_existing_objects_matched()
|
||||
status.assert_nontrivial_match()
|
||||
else:
|
||||
# When graph building, we haven't read any keys, so we don't know
|
||||
# whether the restore will be complete.
|
||||
with self.assertRaisesRegexp(AssertionError, "not restored"):
|
||||
status.assert_consumed()
|
||||
with self.assertRaisesRegexp(AssertionError, "not restored"):
|
||||
status.assert_existing_objects_matched()
|
||||
with self.assertRaisesRegexp(AssertionError, "not restored"):
|
||||
status.assert_nontrivial_match()
|
||||
status.run_restore_ops()
|
||||
self._check_sentinels(root)
|
||||
self._set_sentinels(root)
|
||||
status = object_saver.restore(save_path)
|
||||
status.initialize_or_restore()
|
||||
status.assert_nontrivial_match()
|
||||
self._check_sentinels(root)
|
||||
# Check that there is no error when keys are missing from the name-based
|
||||
# checkpoint.
|
||||
root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable([1.])
|
||||
status = object_saver.restore(save_path)
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
|
||||
def testSaveGraphLoadEager(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with context.graph_mode():
|
||||
save_graph = ops.Graph()
|
||||
with save_graph.as_default(), self.session(
|
||||
graph=save_graph):
|
||||
root = self._initialized_model()
|
||||
save_path = root.save(file_prefix=checkpoint_prefix)
|
||||
with context.eager_mode():
|
||||
root = self._initialized_model()
|
||||
self._set_sentinels(root)
|
||||
root.restore(save_path).assert_consumed()
|
||||
self._check_sentinels(root)
|
||||
|
||||
def testSaveEagerLoadGraph(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with context.eager_mode():
|
||||
root = self._initialized_model()
|
||||
save_path = root.save(file_prefix=checkpoint_prefix)
|
||||
with context.graph_mode():
|
||||
save_graph = ops.Graph()
|
||||
with save_graph.as_default(), self.session(
|
||||
graph=save_graph):
|
||||
root = self._initialized_model()
|
||||
self._set_sentinels(root)
|
||||
root.restore(save_path).assert_consumed().run_restore_ops()
|
||||
self._check_sentinels(root)
|
||||
|
||||
def testIgnoreSaveCounter(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with self.cached_session() as session:
|
||||
# Create and save a model using Saver() before using a Checkpoint. This
|
||||
# generates a snapshot without the Checkpoint's `save_counter`.
|
||||
model = sequential.Sequential()
|
||||
model.add(core.Flatten(input_shape=(1,)))
|
||||
model.add(core.Dense(1))
|
||||
name_saver = saver_lib.Saver(model.trainable_variables)
|
||||
save_path = name_saver.save(
|
||||
sess=session, save_path=checkpoint_prefix, global_step=1)
|
||||
# Checkpoint.restore must successfully load that checkpoint.
|
||||
ckpt = trackable_utils.Checkpoint(model=model)
|
||||
status = ckpt.restore(save_path)
|
||||
status.assert_existing_objects_matched()
|
||||
# It should, however, refuse to load a checkpoint where an unrelated
|
||||
# `save_counter` variable is missing.
|
||||
model.layers[1].var = variables_lib.Variable(0., name="save_counter")
|
||||
status = ckpt.restore(save_path)
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
@ -98,8 +98,6 @@ tf_py_test(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras:engine",
|
||||
"//tensorflow/python/keras/layers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -149,7 +147,6 @@ py_library(
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/training/saving:checkpoint_options",
|
||||
"//tensorflow/python/training/saving:functional_saver",
|
||||
"//tensorflow/python/training/saving:saveable_object_util",
|
||||
@ -188,10 +185,6 @@ tf_py_test(
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//tensorflow/python/keras:engine",
|
||||
"//tensorflow/python/keras/layers",
|
||||
"//tensorflow/python/keras/optimizer_v2",
|
||||
"//tensorflow/python/training/saving:checkpoint_options",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"@six_archive//:six",
|
||||
|
@ -23,26 +23,16 @@ import os
|
||||
import pickle
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy
|
||||
import six
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.layers import normalization
|
||||
from tensorflow.python.layers import core as non_keras_core
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
@ -52,184 +42,13 @@ from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
|
||||
class HasList(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasList, self).__init__()
|
||||
self.layer_list = data_structures.List([core.Dense(3)])
|
||||
self.layer_list.append(core.Dense(4))
|
||||
self.layer_list.extend(
|
||||
[core.Dense(5),
|
||||
core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
|
||||
self.layer_list += [
|
||||
core.Dense(7, bias_regularizer=math_ops.reduce_sum),
|
||||
core.Dense(8)
|
||||
]
|
||||
self.layer_list += (
|
||||
data_structures.List([core.Dense(9)]) + data_structures.List(
|
||||
[core.Dense(10)]))
|
||||
self.layer_list.extend(
|
||||
data_structures.List(
|
||||
list([core.Dense(11)]) + [core.Dense(12)]))
|
||||
self.layers_with_updates = data_structures.List(
|
||||
(normalization.BatchNormalization(),))
|
||||
|
||||
def call(self, x):
|
||||
aggregation = 0.
|
||||
for l in self.layer_list:
|
||||
x = l(x)
|
||||
aggregation += math_ops.reduce_sum(x)
|
||||
bn, = self.layers_with_updates
|
||||
return bn(x) / aggregation
|
||||
|
||||
|
||||
class ListTests(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testTracking(self):
|
||||
model = HasList()
|
||||
output = model(array_ops.ones([32, 2]))
|
||||
self.assertAllEqual([32, 12], output.shape)
|
||||
self.assertEqual(11, len(model.layers))
|
||||
self.assertEqual(10, len(model.layer_list.layers))
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
model.layers,
|
||||
model.layer_list.layers + model.layers_with_updates)
|
||||
for index in range(10):
|
||||
self.assertEqual(3 + index, model.layer_list.layers[index].units)
|
||||
self.assertEqual(2, len(model._checkpoint_dependencies))
|
||||
self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
|
||||
self.assertIs(model.layers_with_updates,
|
||||
model._checkpoint_dependencies[1].ref)
|
||||
self.assertEqual(
|
||||
10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies))
|
||||
self.evaluate([v.initializer for v in model.variables])
|
||||
self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
|
||||
model.load_weights(save_path)
|
||||
self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
|
||||
self.evaluate(model.variables[0]))
|
||||
v = variables.Variable(1.)
|
||||
model.var_list = [v]
|
||||
self.assertIn(v, model.variables)
|
||||
self.assertIn(v, model.trainable_variables)
|
||||
self.assertNotIn(v, model.non_trainable_variables)
|
||||
self.assertIn(model.layer_list[0].trainable_weights[0],
|
||||
model.trainable_weights)
|
||||
|
||||
def testSubModelTracking(self):
|
||||
model = training.Model()
|
||||
model.v = variables.Variable(1.)
|
||||
self.assertIn(model.v, model.trainable_weights)
|
||||
model2 = training.Model()
|
||||
model2.m = [model]
|
||||
self.assertIn(model.v, model2.trainable_weights)
|
||||
|
||||
def testSubSequentialTracking(self):
|
||||
|
||||
class _Subclassed(training.Model):
|
||||
|
||||
def __init__(self, wrapped):
|
||||
super(_Subclassed, self).__init__()
|
||||
self._wrapped = wrapped
|
||||
|
||||
def call(self, x):
|
||||
return self._wrapped(x)
|
||||
|
||||
model = sequential.Sequential()
|
||||
layer = core.Dense(1)
|
||||
model.add(layer)
|
||||
model2 = _Subclassed(model)
|
||||
model2(array_ops.ones([1, 2]))
|
||||
model2.m = [model]
|
||||
self.assertIn(layer.kernel, model2.trainable_weights)
|
||||
|
||||
def testLayerTrackedThroughSequential(self):
|
||||
class AttrDict(dict):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
def ffnet(layer_sizes, name):
|
||||
ff = sequential.Sequential(name=name)
|
||||
for i, width in enumerate(layer_sizes):
|
||||
ff.add(core.Dense(
|
||||
width,
|
||||
activation=("relu" if i < len(layer_sizes)-1 else None)))
|
||||
return ff
|
||||
|
||||
class MyModel2(training.Model):
|
||||
|
||||
def __init__(self, config, name="my_model_2"):
|
||||
super(MyModel2, self).__init__(name=name)
|
||||
self._num_tokens = config.num_tokens
|
||||
|
||||
# list of sub-models
|
||||
self._ffnet = [ffnet(config.module_layers + (self._num_tokens,), "ff")]
|
||||
|
||||
def null_input(self):
|
||||
return array_ops.zeros([1, self._num_tokens], dtype=dtypes.float32)
|
||||
|
||||
def call(self, input_, module_index=None):
|
||||
return self._ffnet[0](input_)
|
||||
|
||||
m2 = MyModel2(AttrDict(
|
||||
num_tokens=5,
|
||||
module_layers=(50, 30)))
|
||||
|
||||
# Construct
|
||||
m2(m2.null_input())
|
||||
self.assertLen(m2.trainable_variables, 6)
|
||||
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.l = [1]
|
||||
json.dumps(obj.l, default=serialization.get_json_type)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testUpdatesForwarded(self):
|
||||
with context.graph_mode():
|
||||
model = HasList()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertGreater(len(model.layers_with_updates[0].updates), 0)
|
||||
self.assertEqual(set(model.layers_with_updates[0].updates),
|
||||
set(model.updates))
|
||||
|
||||
with context.eager_mode():
|
||||
model = HasList()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertEqual(0, len(model.updates))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testLossesForwarded(self):
|
||||
model = HasList()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertEqual(2, len(model.losses))
|
||||
|
||||
def testModelContainersCompareEqual(self):
|
||||
class HasEqualContainers(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasEqualContainers, self).__init__()
|
||||
self.l1 = []
|
||||
self.l2 = []
|
||||
|
||||
model = HasEqualContainers()
|
||||
first_layer = HasEqualContainers()
|
||||
model.l1.append(first_layer)
|
||||
second_layer = HasEqualContainers()
|
||||
model.l2.append(second_layer)
|
||||
self.assertEqual([first_layer, second_layer], model.layers)
|
||||
|
||||
def testNotTrackable(self):
|
||||
class NotTrackable(object):
|
||||
pass
|
||||
@ -245,23 +64,6 @@ class ListTests(test.TestCase):
|
||||
with self.assertRaises(AttributeError):
|
||||
data_structures.List().pop()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTensorConversion(self):
|
||||
|
||||
class ListToTensor(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(ListToTensor, self).__init__()
|
||||
self.l = [1., 2., 3.]
|
||||
|
||||
self.assertAllEqual(
|
||||
[1., 2., 3.],
|
||||
self.evaluate(constant_op.constant(ListToTensor().l)))
|
||||
|
||||
self.assertAllEqual(
|
||||
[1., 2., 3.],
|
||||
self.evaluate(array_ops.pack(ListToTensor().l)))
|
||||
|
||||
def testNesting(self):
|
||||
with context.graph_mode():
|
||||
inner = data_structures.List()
|
||||
@ -315,8 +117,7 @@ class ListTests(test.TestCase):
|
||||
self.assertEqual(l[:-1], [v1, v2, v3])
|
||||
|
||||
def testHash(self):
|
||||
has_sequences = set([data_structures.List(),
|
||||
data_structures.List()])
|
||||
has_sequences = {data_structures.List(), data_structures.List()}
|
||||
self.assertEqual(2, len(has_sequences))
|
||||
self.assertNotIn(data_structures.List(), has_sequences)
|
||||
|
||||
@ -454,13 +255,6 @@ class ListWrapperTest(test.TestCase):
|
||||
l.append(1)
|
||||
self.assertEqual([1], l_wrapper)
|
||||
|
||||
def testLayerCollectionWithExternalMutation(self):
|
||||
l = []
|
||||
l_wrapper = data_structures.ListWrapper(l)
|
||||
layer = core.Dense(1)
|
||||
l.append(layer)
|
||||
self.assertEqual([layer], l_wrapper.layers)
|
||||
|
||||
def testNotHashable(self):
|
||||
with self.assertRaises(TypeError):
|
||||
hash(data_structures.ListWrapper())
|
||||
@ -538,50 +332,8 @@ class ListWrapperTest(test.TestCase):
|
||||
return l._checkpoint_dependencies # pylint: disable=protected-access
|
||||
|
||||
|
||||
class HasMapping(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasMapping, self).__init__()
|
||||
self.layer_dict = data_structures.Mapping(output=core.Dense(7))
|
||||
self.layer_dict["norm"] = data_structures.List()
|
||||
self.layer_dict["dense"] = data_structures.List()
|
||||
self.layer_dict["dense"].extend(
|
||||
[core.Dense(5),
|
||||
core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
|
||||
self.layer_dict["norm"].append(
|
||||
normalization.BatchNormalization())
|
||||
self.layer_dict["norm"].append(
|
||||
normalization.BatchNormalization())
|
||||
|
||||
def call(self, x):
|
||||
aggregation = 0.
|
||||
for norm, dense in zip(self.layer_dict["norm"], self.layer_dict["dense"]):
|
||||
x = norm(dense(x))
|
||||
aggregation += math_ops.reduce_sum(x)
|
||||
return self.layer_dict["output"](x) / aggregation
|
||||
|
||||
|
||||
class MappingTests(test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTracking(self):
|
||||
model = HasMapping()
|
||||
output = model(array_ops.ones([32, 2]))
|
||||
self.assertAllEqual([32, 7], output.shape.as_list())
|
||||
self.assertEqual(5, len(model.layers))
|
||||
six.assertCountEqual(self, model.layers, model.layer_dict.layers)
|
||||
self.assertEqual(1, len(model._checkpoint_dependencies))
|
||||
self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref)
|
||||
self.evaluate([v.initializer for v in model.variables])
|
||||
test_var = model.layer_dict["output"].kernel
|
||||
self.evaluate(test_var.assign(array_ops.ones([6, 7])))
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
self.evaluate(test_var.assign(array_ops.zeros([6, 7])))
|
||||
model.load_weights(save_path)
|
||||
self.assertAllEqual(numpy.ones([6, 7]),
|
||||
self.evaluate(test_var))
|
||||
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.d = {"a": 2}
|
||||
@ -605,20 +357,6 @@ class MappingTests(test.TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
mapping[1] = data_structures.List()
|
||||
|
||||
def testLayerCollectionWithExternalMutation(self):
|
||||
d = {}
|
||||
root = tracking.AutoTrackable()
|
||||
root.wrapper = d
|
||||
self.assertEqual([], root.wrapper.layers)
|
||||
self.assertEqual([], root.wrapper.trainable_weights)
|
||||
layer1 = core.Dense(1)
|
||||
layer2 = core.Dense(1)
|
||||
d["a"] = layer1
|
||||
d["b"] = layer2
|
||||
self.assertEqual([layer1, layer2], root.wrapper.layers)
|
||||
# The layers have still not created variables
|
||||
self.assertEqual([], root.wrapper.trainable_weights)
|
||||
|
||||
def testHashing(self):
|
||||
has_mappings = set([data_structures.Mapping(),
|
||||
data_structures.Mapping()])
|
||||
@ -633,101 +371,6 @@ class MappingTests(test.TestCase):
|
||||
with self.assertRaisesRegexp(TypeError, "unhashable"):
|
||||
set([a.d])
|
||||
|
||||
def testDictWrapperBadKeys(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = {}
|
||||
a.d[1] = data_structures.List()
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "non-string key"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testDictWrapperNoDependency(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = data_structures.NoDependency({})
|
||||
a.d[1] = [3]
|
||||
self.assertEqual([a], util.list_objects(a))
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testNonStringKeyNotTrackableValue(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = {}
|
||||
a.d["a"] = [3]
|
||||
a.d[1] = data_structures.NoDependency([3])
|
||||
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testNonAppendNotTrackable(self):
|
||||
# Non-append mutations (deleting or overwriting values) are OK when the
|
||||
# values aren't tracked.
|
||||
a = tracking.AutoTrackable()
|
||||
a.d = {}
|
||||
a.d["a"] = [3]
|
||||
a.d[1] = 3
|
||||
a.d[1] = 2
|
||||
self.assertEqual(2, a.d[1])
|
||||
del a.d[1]
|
||||
a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
|
||||
second = tracking.AutoTrackable()
|
||||
a.d[2] = data_structures.NoDependency(second)
|
||||
self.assertIs(second, a.d[2])
|
||||
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
|
||||
model = training.Model()
|
||||
model.sub = a
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testPopNoSave(self):
|
||||
model = training.Model()
|
||||
model.d = {}
|
||||
model.d["a"] = []
|
||||
model.d.pop("a")
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "Unable to save"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testExternalModificationNoSave(self):
|
||||
model = training.Model()
|
||||
external_reference = {}
|
||||
model.d = external_reference
|
||||
external_reference["a"] = []
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testOverwriteCanStillSave(self):
|
||||
model = training.Model()
|
||||
model.d = {}
|
||||
model.d["a"] = {}
|
||||
model.d["a"] = {}
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testIter(self):
|
||||
model = training.Model()
|
||||
model.d = {1: 3}
|
||||
model.d[1] = 3
|
||||
self.assertEqual([1], list(model.d))
|
||||
new_dict = {}
|
||||
# This update() is super tricky. If the dict wrapper subclasses dict,
|
||||
# CPython will access its storage directly instead of calling any
|
||||
# methods/properties on the object. So the options are either not to
|
||||
# subclass dict (in which case update will call normal iter methods, but the
|
||||
# object won't pass isinstance checks) or to subclass dict and keep that
|
||||
# storage updated (no shadowing all its methods like ListWrapper).
|
||||
new_dict.update(model.d)
|
||||
self.assertEqual({1: 3}, new_dict)
|
||||
|
||||
def testListShallowCopy(self):
|
||||
root = tracking.AutoTrackable()
|
||||
orig_list = [[1.]]
|
||||
@ -871,157 +514,13 @@ class MappingTests(test.TestCase):
|
||||
self.assertIs(first_trace, second_trace)
|
||||
|
||||
|
||||
class HasTuple(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasTuple, self).__init__()
|
||||
self.layer_list = (
|
||||
core.Dense(3), core.Dense(4),
|
||||
core.Dense(5, kernel_regularizer=math_ops.reduce_sum))
|
||||
self.layers_with_updates = (normalization.BatchNormalization(),)
|
||||
|
||||
def call(self, x):
|
||||
aggregation = 0.
|
||||
for l in self.layer_list:
|
||||
x = l(x)
|
||||
aggregation += math_ops.reduce_sum(x)
|
||||
bn, = self.layers_with_updates
|
||||
return bn(x) / aggregation
|
||||
|
||||
|
||||
class TupleTests(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTracking(self):
|
||||
model = HasTuple()
|
||||
output = model(array_ops.ones([32, 2]))
|
||||
self.assertAllEqual([32, 5], output.shape.as_list())
|
||||
self.assertLen(model.layers, 4)
|
||||
self.assertLen(model.layer_list.layers, 3)
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
model.layers,
|
||||
tuple(model.layer_list.layers) + model.layers_with_updates)
|
||||
self.assertEqual(3, model.layer_list.layers[0].units)
|
||||
self.assertEqual(4, model.layer_list.layers[1].units)
|
||||
self.assertEqual(5, model.layer_list.layers[2].units)
|
||||
self.assertLen(model._checkpoint_dependencies, 2)
|
||||
self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
|
||||
self.assertIs(model.layers_with_updates,
|
||||
model._checkpoint_dependencies[1].ref)
|
||||
self.assertLen(
|
||||
model._checkpoint_dependencies[0].ref._checkpoint_dependencies, 3)
|
||||
self.evaluate([v.initializer for v in model.variables])
|
||||
self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
|
||||
save_path = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
model.save_weights(save_path)
|
||||
self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
|
||||
model.load_weights(save_path)
|
||||
self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
|
||||
self.evaluate(model.variables[0]))
|
||||
v = variables.Variable(1.)
|
||||
model.var_list = (v,)
|
||||
self.assertIn(id(v), [id(obj) for obj in model.variables])
|
||||
self.assertIn(id(v), [id(obj) for obj in model.trainable_variables])
|
||||
self.assertNotIn(id(v), [id(obj) for obj in model.non_trainable_variables])
|
||||
self.assertIn(id(model.layer_list[0].trainable_weights[0]),
|
||||
[id(obj) for obj in model.trainable_weights])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Module", module.Module),
|
||||
("Model", training.Model),
|
||||
)
|
||||
def testSubModelTracking(self, module_subclass):
|
||||
model = module_subclass()
|
||||
model.v = variables.Variable(1.)
|
||||
self.assertIn(model.v, model.trainable_variables)
|
||||
model2 = module_subclass()
|
||||
model2.m = (model,)
|
||||
self.assertIn(model.v, model2.trainable_variables)
|
||||
|
||||
def testSubSequentialTracking(self):
|
||||
|
||||
class _Subclassed(training.Model):
|
||||
|
||||
def __init__(self, wrapped):
|
||||
super(_Subclassed, self).__init__()
|
||||
self._wrapped = wrapped
|
||||
|
||||
def call(self, x):
|
||||
return self._wrapped(x)
|
||||
|
||||
model = sequential.Sequential()
|
||||
layer = core.Dense(1)
|
||||
model.add(layer)
|
||||
model2 = _Subclassed(model)
|
||||
model2(array_ops.ones([1, 2]))
|
||||
model2.m = (model,)
|
||||
self.assertIn(layer.kernel, model2.trainable_weights)
|
||||
|
||||
def testJSONSerialization(self):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.l = (1,)
|
||||
json.dumps(obj.l, default=serialization.get_json_type)
|
||||
|
||||
def testUpdatesForwarded(self):
|
||||
with ops.Graph().as_default():
|
||||
model = HasTuple()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertNotEmpty(model.layers_with_updates[0].updates)
|
||||
self.assertEqual(set(model.layers_with_updates[0].updates),
|
||||
set(model.updates))
|
||||
|
||||
model = HasTuple()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertEmpty(model.updates)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testLossesForwarded(self):
|
||||
model = HasTuple()
|
||||
model_input = array_ops.ones([32, 2])
|
||||
model(model_input)
|
||||
self.assertLen(model.losses, 1)
|
||||
|
||||
def testModelContainersCompareEqual(self):
|
||||
class HasEqualContainers(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(HasEqualContainers, self).__init__()
|
||||
self.l1 = ()
|
||||
self.l2 = ()
|
||||
|
||||
model = HasEqualContainers()
|
||||
first_layer = HasEqualContainers()
|
||||
model.l1 = (first_layer,)
|
||||
second_layer = HasEqualContainers()
|
||||
model.l2 = (second_layer,)
|
||||
self.assertEqual((first_layer,), model.l1)
|
||||
d = {model.l1: 1, model.l2: 2}
|
||||
self.assertEqual(1, d[model.l1])
|
||||
self.assertEqual(1, d[(first_layer,)])
|
||||
self.assertEqual(2, d[model.l2])
|
||||
self.assertEqual(2, d[(second_layer,)])
|
||||
self.assertEqual([first_layer, second_layer], model.layers)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testTensorConversion(self):
|
||||
|
||||
class TupleToTensor(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(TupleToTensor, self).__init__()
|
||||
self.l = (1., 2., 3.)
|
||||
|
||||
self.assertAllEqual(
|
||||
(1., 2., 3.),
|
||||
self.evaluate(constant_op.constant(TupleToTensor().l)))
|
||||
|
||||
self.assertAllEqual(
|
||||
(1., 2., 3.),
|
||||
self.evaluate(array_ops.pack(TupleToTensor().l)))
|
||||
|
||||
def testNonLayerVariables(self):
|
||||
v = resource_variable_ops.ResourceVariable([1.])
|
||||
l = data_structures._TupleWrapper((v,))
|
||||
|
@ -25,14 +25,10 @@ import time
|
||||
import timeit
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util
|
||||
@ -73,28 +69,6 @@ class InterfaceTests(test.TestCase):
|
||||
(_, dep_object), = root._checkpoint_dependencies
|
||||
self.assertIs(duplicate_name_dep, dep_object)
|
||||
|
||||
def testNoDependency(self):
|
||||
root = tracking.AutoTrackable()
|
||||
hasdep = tracking.AutoTrackable()
|
||||
root.hasdep = hasdep
|
||||
nodep = tracking.AutoTrackable()
|
||||
root.nodep = data_structures.NoDependency(nodep)
|
||||
self.assertEqual(1, len(root._checkpoint_dependencies))
|
||||
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
|
||||
self.assertIs(root.hasdep, hasdep)
|
||||
self.assertIs(root.nodep, nodep)
|
||||
|
||||
class NoDependencyModel(training.Model):
|
||||
|
||||
@base.no_automatic_dependency_tracking
|
||||
def __init__(self):
|
||||
super(NoDependencyModel, self).__init__()
|
||||
self.a = []
|
||||
self.b = tracking.AutoTrackable()
|
||||
|
||||
nodeps = NoDependencyModel()
|
||||
self.assertEqual([nodeps], util.list_objects(nodeps))
|
||||
|
||||
def testRemoveDependency(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.a = tracking.AutoTrackable()
|
||||
@ -183,43 +157,6 @@ class InterfaceTests(test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
|
||||
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDictionariesBasic(self):
|
||||
a = training.Model()
|
||||
b = training.Model()
|
||||
a.attribute = {"b": b}
|
||||
c = training.Model()
|
||||
a.attribute["c"] = []
|
||||
a.attribute["c"].append(c)
|
||||
a_deps = util.list_objects(a)
|
||||
self.assertIn(b, a_deps)
|
||||
self.assertIn(c, a_deps)
|
||||
self.assertIs(b, a.attribute["b"])
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
["b", "c"],
|
||||
[dep.name for dep in a.attribute._checkpoint_dependencies])
|
||||
self.assertEqual([b, c], a.layers)
|
||||
self.assertEqual([b, c], a.attribute.layers)
|
||||
self.assertEqual([c], a.attribute["c"].layers)
|
||||
checkpoint = util.Checkpoint(a=a)
|
||||
save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
with self.cached_session():
|
||||
checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testNoDepList(self):
|
||||
a = training.Model()
|
||||
a.l1 = data_structures.NoDependency([])
|
||||
a.l1.insert(1, 0)
|
||||
self.assertTrue(isinstance(a.l1, list))
|
||||
checkpoint = util.Checkpoint(a=a)
|
||||
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
a.l2 = []
|
||||
a.l2.insert(1, module.Module())
|
||||
with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
|
||||
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAssertions(self):
|
||||
a = tracking.AutoTrackable()
|
||||
|
@ -16,25 +16,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
import weakref
|
||||
|
||||
from absl.testing import parameterized
|
||||
import six
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -46,7 +39,6 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.saving import checkpoint_options
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
@ -62,44 +54,8 @@ class NonLayerTrackable(tracking.AutoTrackable):
|
||||
self, name="a_variable", shape=[])
|
||||
|
||||
|
||||
# pylint: disable=not-callable
|
||||
class MyModel(training.Model):
|
||||
"""A concrete Model for testing."""
|
||||
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self._named_dense = core.Dense(1, use_bias=True)
|
||||
self._second = core.Dense(1, use_bias=False)
|
||||
# We can still track Trackables which aren't Layers.
|
||||
self._non_layer = NonLayerTrackable()
|
||||
|
||||
def call(self, values):
|
||||
ret = self._second(self._named_dense(values))
|
||||
return ret
|
||||
|
||||
|
||||
class InterfaceTests(test.TestCase):
|
||||
|
||||
def testLayerDeduplication(self):
|
||||
model = training.Model()
|
||||
layer_one = core.Dense(1)
|
||||
layer_two = core.Dense(1)
|
||||
model.other_path = [layer_one, layer_two]
|
||||
model.l2 = layer_two
|
||||
model.l1 = layer_one
|
||||
self.assertEqual([layer_one, layer_two], model.layers)
|
||||
|
||||
def testSaveWithOnlyKerasSession(self):
|
||||
|
||||
with ops.Graph().as_default():
|
||||
inp = input_layer.Input([1])
|
||||
dense = core.Dense(1)(inp)
|
||||
model = training.Model(inp, dense)
|
||||
model.compile(optimizer="sgd", loss="mse")
|
||||
model.fit([1.], [2.])
|
||||
checkpoint = trackable_utils.Checkpoint(model=model)
|
||||
checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testAddVariable(self):
|
||||
obj = NonLayerTrackable()
|
||||
@ -184,22 +140,6 @@ class InterfaceTests(test.TestCase):
|
||||
self.assertEqual(dtypes.float64, v2.dtype)
|
||||
self.assertAllEqual([1., 1., 1.], self.evaluate(v2))
|
||||
|
||||
def testObjectMetadata(self):
|
||||
with context.eager_mode():
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
dense = core.Dense(1)
|
||||
checkpoint = trackable_utils.Checkpoint(dense=dense)
|
||||
dense(constant_op.constant([[1.]]))
|
||||
save_path = checkpoint.save(checkpoint_prefix)
|
||||
|
||||
objects = trackable_utils.object_metadata(save_path)
|
||||
all_variable_names = []
|
||||
for obj in objects.nodes:
|
||||
for attribute in obj.attributes:
|
||||
all_variable_names.append(attribute.full_name)
|
||||
self.assertIn("dense/kernel", all_variable_names)
|
||||
|
||||
def testNotTrackable(self):
|
||||
|
||||
class CallsFunctionalStuff(
|
||||
@ -268,100 +208,6 @@ class _OwnsMirroredVariables(base.Trackable):
|
||||
|
||||
class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testNamingWithOptimizer(self):
|
||||
input_value = constant_op.constant([[3.]])
|
||||
model = MyModel()
|
||||
# A nuisance Model using the same optimizer. Its slot variables should not
|
||||
# go in the checkpoint, since it is never depended on.
|
||||
other_model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
step = training_util.get_or_create_global_step()
|
||||
root_trackable = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model, step=step)
|
||||
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = control_flow_ops.group(
|
||||
optimizer.apply_gradients(zip(gradients, variables)),
|
||||
step.assign_add(1))
|
||||
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = other_model(input_value)
|
||||
variables = other_model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
optimizer.apply_gradients(zip(gradients, variables))
|
||||
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
|
||||
root_trackable).serialize_object_graph()
|
||||
expected_slot_keys = (
|
||||
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
|
||||
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
|
||||
"model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
|
||||
"model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
|
||||
"model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
|
||||
"model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
|
||||
)
|
||||
expected_checkpoint_names = (
|
||||
# Created in the root node, so no prefix.
|
||||
"step",
|
||||
"model/_second/kernel",
|
||||
"model/_named_dense/kernel",
|
||||
"model/_named_dense/bias",
|
||||
# non-Layer dependency of the model
|
||||
"model/_non_layer/a_variable",
|
||||
"optimizer/learning_rate",
|
||||
"optimizer/beta_1",
|
||||
"optimizer/beta_2",
|
||||
"optimizer/iter",
|
||||
"optimizer/decay",
|
||||
) + expected_slot_keys
|
||||
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
expected_checkpoint_names = [
|
||||
name + suffix for name in expected_checkpoint_names]
|
||||
named_variables = {v.name: v for v in named_variables}
|
||||
six.assertCountEqual(self, expected_checkpoint_names,
|
||||
named_variables.keys())
|
||||
# Check that we've mapped to the right variable objects (not exhaustive)
|
||||
self.assertEqual(
|
||||
"global_step",
|
||||
named_variables["step" + suffix].full_name)
|
||||
self.assertEqual(
|
||||
"my_model/dense_1/kernel",
|
||||
named_variables["model/_second/kernel" + suffix].full_name)
|
||||
self.assertEqual(
|
||||
"my_model/dense/kernel",
|
||||
named_variables["model/_named_dense/kernel" + suffix].full_name)
|
||||
self.assertEqual("Adam/beta_1",
|
||||
named_variables["optimizer/beta_1" + suffix].full_name)
|
||||
self.assertEqual("Adam/beta_2",
|
||||
named_variables["optimizer/beta_2" + suffix].full_name)
|
||||
# Spot check the generated protocol buffers.
|
||||
self.assertEqual("optimizer",
|
||||
serialized_graph.nodes[0].children[1].local_name)
|
||||
optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
|
||||
1].node_id]
|
||||
children = [node.local_name for node in optimizer_node.children]
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
# hyper variable dependencies
|
||||
["beta_1", "beta_2", "iter", "decay", "learning_rate"],
|
||||
children)
|
||||
serialized_slot_keys = []
|
||||
for slot in optimizer_node.slot_variables:
|
||||
for attribute in (
|
||||
serialized_graph.nodes[slot.slot_variable_node_id].attributes):
|
||||
serialized_slot_keys.append(attribute.checkpoint_key)
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
[key + suffix for key in expected_slot_keys],
|
||||
serialized_slot_keys)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testMoreComplexSaveableReturned(self):
|
||||
v = _OwnsMirroredVariables()
|
||||
@ -432,174 +278,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
if op.type in ("SaveV2", "RestoreV2"):
|
||||
self.assertEqual(localhost, op.device)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testSaveRestore(self):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root_trackable = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
self.assertFalse(root_trackable.save_counter.trainable)
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
|
||||
m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
|
||||
self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
|
||||
save_path = root_trackable.save(file_prefix=prefix)
|
||||
self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
|
||||
self.evaluate(state_ops.assign(root_trackable.save_counter, 3))
|
||||
optimizer_variables = self.evaluate(
|
||||
sorted(optimizer.variables(), key=lambda v: v.name))
|
||||
self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
|
||||
# Immediate restoration
|
||||
status = root_trackable.restore(save_path=save_path).assert_consumed()
|
||||
status.run_restore_ops()
|
||||
self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
|
||||
self.assertAllEqual(1, self.evaluate(root_trackable.save_counter))
|
||||
self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
|
||||
if not context.executing_eagerly():
|
||||
return # Restore-on-create is only supported when executing eagerly
|
||||
on_create_model = MyModel()
|
||||
on_create_optimizer = adam.Adam(0.001)
|
||||
on_create_root = trackable_utils.Checkpoint(
|
||||
optimizer=on_create_optimizer, model=on_create_model)
|
||||
# Deferred restoration
|
||||
status = on_create_root.restore(save_path=save_path)
|
||||
status.assert_nontrivial_match()
|
||||
status.assert_existing_objects_matched()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
on_create_model(constant_op.constant([[3.]])) # create variables
|
||||
self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
|
||||
self.assertAllEqual([42.],
|
||||
self.evaluate(
|
||||
on_create_model._named_dense.variables[1]))
|
||||
on_create_m_bias_slot = on_create_optimizer.get_slot(
|
||||
on_create_model._named_dense.variables[1], "m")
|
||||
status.assert_existing_objects_matched()
|
||||
if not context.executing_eagerly():
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
# Optimizer slot variables are created when the original variable is
|
||||
# restored.
|
||||
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
|
||||
dummy_var = resource_variable_ops.ResourceVariable([1.])
|
||||
on_create_optimizer.minimize(loss=dummy_var.read_value,
|
||||
var_list=[dummy_var])
|
||||
status.assert_existing_objects_matched()
|
||||
status.assert_consumed()
|
||||
self.assertAllEqual(
|
||||
optimizer_variables,
|
||||
# Creation order is different, so .variables() needs to be re-sorted.
|
||||
self.evaluate(sorted(optimizer.variables(), key=lambda v: v.name)))
|
||||
|
||||
# TODO(allenl): Debug garbage created by this test in python3.
|
||||
def testDeferredRestorationUsageEager(self):
|
||||
"""An idiomatic eager execution example."""
|
||||
num_training_steps = 10
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
for training_continuation in range(3):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
root.restore(checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory))
|
||||
for _ in range(num_training_steps):
|
||||
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
|
||||
input_value = constant_op.constant([[3.]])
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
optimizer.apply_gradients(zip(gradients, variables))
|
||||
root.save(file_prefix=checkpoint_prefix)
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
root.optimizer.iterations.numpy())
|
||||
|
||||
def testUsageGraph(self):
|
||||
"""Expected usage when graph building."""
|
||||
with context.graph_mode():
|
||||
num_training_steps = 10
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
for training_continuation in range(3):
|
||||
with ops.Graph().as_default():
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.CheckpointV1(
|
||||
optimizer=optimizer, model=model)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
|
||||
checkpoint_path = checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)
|
||||
with self.session(graph=ops.get_default_graph()) as session:
|
||||
status = root.restore(save_path=checkpoint_path)
|
||||
status.initialize_or_restore(session=session)
|
||||
if checkpoint_path is None:
|
||||
self.assertEqual(0, training_continuation)
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
else:
|
||||
status.assert_consumed()
|
||||
status.assert_existing_objects_matched()
|
||||
for _ in range(num_training_steps):
|
||||
session.run(train_op)
|
||||
root.save(file_prefix=checkpoint_prefix, session=session)
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
session.run(root.optimizer.iterations))
|
||||
self.assertEqual(training_continuation + 1,
|
||||
session.run(root.save_counter))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAgnosticUsage(self):
|
||||
"""Graph/eager agnostic usage."""
|
||||
# Does create garbage when executing eagerly due to ops.Graph() creation.
|
||||
num_training_steps = 10
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
def _train_fn(model, input_value):
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
for training_continuation in range(3):
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
root, checkpoint_directory, max_to_keep=1)
|
||||
status = root.restore(save_path=manager.latest_checkpoint)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
train_fn = functools.partial(_train_fn, model, input_value)
|
||||
if not context.executing_eagerly():
|
||||
train_fn = functools.partial(self.evaluate, train_fn())
|
||||
status.initialize_or_restore()
|
||||
for _ in range(num_training_steps):
|
||||
train_fn()
|
||||
manager.save()
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
self.evaluate(root.optimizer.iterations))
|
||||
self.assertEqual(training_continuation + 1,
|
||||
self.evaluate(root.save_counter))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testFreezing(self):
|
||||
with test_util.use_gpu():
|
||||
@ -656,31 +334,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
self.fail("%s should have suffix %s" % (path, expected_suffix))
|
||||
self.evaluate(step.assign_add(2))
|
||||
|
||||
def testPartialRestoreWarningObject(self):
|
||||
with context.eager_mode():
|
||||
optimizer = adam.Adam(0.0)
|
||||
original_root = trackable_utils.Checkpoint(v1=variables_lib.Variable(2.),
|
||||
v2=variables_lib.Variable(3.),
|
||||
optimizer=optimizer)
|
||||
# Create a slot variable to save
|
||||
optimizer.minimize(original_root.v1.read_value, [original_root.v1])
|
||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
save_path = original_root.save(prefix)
|
||||
partial_root = trackable_utils.Checkpoint(v1=variables_lib.Variable(0.))
|
||||
weak_partial_root = weakref.ref(partial_root)
|
||||
weak_v1 = weakref.ref(partial_root.v1)
|
||||
partial_root.restore(save_path)
|
||||
self.assertEqual(2., partial_root.v1.numpy())
|
||||
with test.mock.patch.object(logging, "warning") as mock_log:
|
||||
del partial_root
|
||||
self.assertIsNone(weak_partial_root())
|
||||
self.assertIsNone(weak_v1())
|
||||
messages = str(mock_log.call_args_list)
|
||||
self.assertIn("(root).v2'", messages)
|
||||
self.assertIn("(root).optimizer's state 'm' for (root).v1", messages)
|
||||
self.assertNotIn("(root).v1'", messages)
|
||||
self.assertIn("expect_partial()", messages)
|
||||
|
||||
def testPartialRestoreWarningAttribute(self):
|
||||
with context.eager_mode():
|
||||
original_root = trackable_utils.Checkpoint(v1=variables_lib.Variable(2.),
|
||||
@ -734,49 +387,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
self.assertIsNone(weak_v1())
|
||||
self.assertEmpty(mock_log.call_args_list)
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testWithDefun(self):
|
||||
num_training_steps = 2
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
for training_continuation in range(3):
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
# Don't actually train so we can test variable values
|
||||
optimizer = adam.Adam(0.)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
checkpoint_path = checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)
|
||||
status = root.restore(save_path=checkpoint_path)
|
||||
def train_fn():
|
||||
@def_function.function
|
||||
def _call_model(x):
|
||||
return model(x)
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = _call_model(constant_op.constant([[3.]]))
|
||||
gradients = tape.gradient(loss, model.variables)
|
||||
return optimizer.apply_gradients(zip(gradients, model.variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn = functools.partial(
|
||||
self.evaluate, train_fn())
|
||||
status.initialize_or_restore()
|
||||
for _ in range(num_training_steps):
|
||||
train_fn()
|
||||
if training_continuation > 0:
|
||||
status.assert_consumed()
|
||||
self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
|
||||
else:
|
||||
self.evaluate(model.variables[0].assign([[42.]]))
|
||||
root.save(file_prefix=checkpoint_prefix)
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
self.evaluate(optimizer.iterations))
|
||||
self.assertEqual(training_continuation + 1,
|
||||
self.evaluate(root.save_counter))
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
def _get_checkpoint_name(self, name):
|
||||
root = tracking.AutoTrackable()
|
||||
trackable_utils.add_variable(
|
||||
@ -819,35 +429,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
self.assertEqual("..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE",
|
||||
named_variable.name)
|
||||
|
||||
def testAnonymousVarsInInit(self):
|
||||
|
||||
class Model(training.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.w = resource_variable_ops.ResourceVariable(0.0)
|
||||
self.b = resource_variable_ops.ResourceVariable(0.0)
|
||||
self.vars = [self.w, self.b]
|
||||
|
||||
def call(self, x):
|
||||
return x * self.w + self.b
|
||||
|
||||
with context.eager_mode():
|
||||
model = Model()
|
||||
optimizer = adam.Adam(learning_rate=0.05)
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
checkpoint = trackable_utils.Checkpoint(
|
||||
model=model, optimizer=optimizer)
|
||||
for _ in range(2):
|
||||
checkpoint.save(checkpoint_prefix)
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = (constant_op.constant(1.)
|
||||
- model(constant_op.constant(1.))) ** 2
|
||||
grad = tape.gradient(loss, model.vars)
|
||||
optimizer.apply_gradients(
|
||||
[(g, v) for g, v in zip(grad, model.vars)])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testLateDependencyTracking(self):
|
||||
|
||||
@ -909,72 +490,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
status.run_restore_ops()
|
||||
self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDeferredSlotRestoration(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
|
||||
root = trackable_utils.Checkpoint()
|
||||
root.var = trackable_utils.add_variable(
|
||||
root, name="var", initializer=0.)
|
||||
optimizer = adam.Adam(0.1)
|
||||
variables = [root.var]
|
||||
gradients = [1.]
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
# Note that `optimizer` has not been added as a dependency of
|
||||
# `root`. Create a one-off grouping so that slot variables for `root.var`
|
||||
# get initialized too.
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
|
||||
self.evaluate(train_op)
|
||||
self.evaluate(state_ops.assign(root.var, 12.))
|
||||
no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
|
||||
root.optimizer = optimizer
|
||||
self.evaluate(state_ops.assign(root.var, 13.))
|
||||
self.evaluate(state_ops.assign(
|
||||
optimizer.get_slot(slot_name="m", var=root.var),
|
||||
14.))
|
||||
slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
|
||||
new_root = trackable_utils.Checkpoint()
|
||||
# Load the slot-containing checkpoint (deferred), then immediately overwrite
|
||||
# the non-slot variable (also deferred).
|
||||
slot_status = new_root.restore(slots_path)
|
||||
no_slot_status = new_root.restore(no_slots_path)
|
||||
with self.assertRaises(AssertionError):
|
||||
no_slot_status.assert_consumed()
|
||||
new_root.var = trackable_utils.add_variable(
|
||||
new_root, name="var", shape=[])
|
||||
no_slot_status.assert_consumed()
|
||||
no_slot_status.run_restore_ops()
|
||||
self.assertEqual(12., self.evaluate(new_root.var))
|
||||
new_root.optimizer = adam.Adam(0.1)
|
||||
slot_status.assert_existing_objects_matched()
|
||||
if not context.executing_eagerly():
|
||||
with self.assertRaisesRegexp(AssertionError, "Unresolved object"):
|
||||
slot_status.assert_consumed()
|
||||
self.assertEqual(12., self.evaluate(new_root.var))
|
||||
if context.executing_eagerly():
|
||||
# Slot variables are only created with restoring initializers when
|
||||
# executing eagerly.
|
||||
self.assertEqual(14., self.evaluate(
|
||||
new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
|
||||
else:
|
||||
# Slot variables are not created eagerly when graph building.
|
||||
with self.assertRaises(KeyError):
|
||||
new_root.optimizer.get_slot(slot_name="m", var=new_root.var)
|
||||
variables = [new_root.var]
|
||||
gradients = [1.]
|
||||
train_op = new_root.optimizer.apply_gradients(zip(gradients, variables))
|
||||
# The slot variable now exists; restore() didn't create it, but we should
|
||||
# now have a restore op for it.
|
||||
slot_status.run_restore_ops()
|
||||
if not context.executing_eagerly():
|
||||
# The train op hasn't run when graph building, so the slot variable has
|
||||
# its restored value. It has run in eager, so the value will be different.
|
||||
self.assertEqual(14., self.evaluate(
|
||||
new_root.optimizer.get_slot(slot_name="m", var=new_root.var)))
|
||||
self.evaluate(train_op)
|
||||
slot_status.assert_consumed()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testOverlappingRestores(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
@ -1154,24 +669,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
status.run_restore_ops()
|
||||
self.assertEqual(4., self.evaluate(recreated_var1))
|
||||
|
||||
def testManySavesGraph(self):
|
||||
"""Saves after the first should not modify the graph."""
|
||||
with context.graph_mode():
|
||||
graph = ops.Graph()
|
||||
with graph.as_default(), self.session(graph):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
obj = trackable_utils.Checkpoint()
|
||||
obj.var = variables_lib.Variable(0., name="v")
|
||||
obj.opt = adam.Adam(0.1)
|
||||
variables = [obj.var]
|
||||
gradients = [1.]
|
||||
obj.opt.apply_gradients(zip(gradients, variables))
|
||||
self.evaluate(trackable_utils.gather_initializers(obj))
|
||||
obj.save(checkpoint_prefix)
|
||||
graph.finalize()
|
||||
obj.save(checkpoint_prefix)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCheckpointState(self):
|
||||
# No checkpoints are deleted by default
|
||||
@ -1237,146 +734,6 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
self.assertEqual(1, self.evaluate(checkpoint.var_1))
|
||||
self.assertEqual(0, self.evaluate(checkpoint.var_0))
|
||||
|
||||
def testManyRestoresGraph(self):
|
||||
"""Restores after the first should not modify the graph."""
|
||||
with context.graph_mode():
|
||||
graph = ops.Graph()
|
||||
with graph.as_default(), self.session(graph):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
obj = trackable_utils.Checkpoint()
|
||||
obj.var = variables_lib.Variable(0., name="v")
|
||||
obj.opt = adam.Adam(0.1)
|
||||
variables = [obj.var]
|
||||
gradients = [1.]
|
||||
obj.opt.apply_gradients(zip(gradients, variables))
|
||||
self.evaluate(trackable_utils.gather_initializers(obj))
|
||||
save_path = obj.save(checkpoint_prefix)
|
||||
obj.restore(save_path)
|
||||
graph.finalize()
|
||||
obj.restore(save_path)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_sequential(self):
|
||||
model = sequential.Sequential()
|
||||
checkpoint = trackable_utils.Checkpoint(model=model)
|
||||
model.add(core.Dense(4))
|
||||
second_dense = core.Dense(5)
|
||||
model.add(second_dense)
|
||||
model(constant_op.constant([[1.]]))
|
||||
checkpoint.restore(None).initialize_or_restore()
|
||||
self.evaluate(second_dense.bias.assign(
|
||||
constant_op.constant([1., 2., 3., 4., 5.])))
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
save_path = checkpoint.save(checkpoint_prefix)
|
||||
self.evaluate(second_dense.bias.assign(
|
||||
constant_op.constant([5., 6., 7., 8., 9.])))
|
||||
checkpoint.restore(save_path).assert_consumed().run_restore_ops()
|
||||
self.assertAllEqual([1., 2., 3., 4., 5.], self.evaluate(second_dense.bias))
|
||||
|
||||
deferred_sequential = sequential.Sequential()
|
||||
deferred_sequential_checkpoint = trackable_utils.Checkpoint(
|
||||
model=deferred_sequential)
|
||||
status = deferred_sequential_checkpoint.restore(save_path)
|
||||
deferred_sequential.add(core.Dense(4))
|
||||
deferred_second_dense = core.Dense(5)
|
||||
deferred_sequential.add(deferred_second_dense)
|
||||
deferred_sequential(constant_op.constant([[1.]]))
|
||||
status.run_restore_ops()
|
||||
self.assertAllEqual([1., 2., 3., 4., 5.],
|
||||
self.evaluate(deferred_second_dense.bias))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_initialize_if_not_restoring(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
model=model) # Do not save the optimizer with the checkpoint.
|
||||
optimizer_checkpoint = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer)
|
||||
|
||||
checkpoint_path = checkpoint_management.latest_checkpoint(
|
||||
checkpoint_directory)
|
||||
status = root.restore(save_path=checkpoint_path)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
def train_fn():
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn = functools.partial(self.evaluate, train_fn())
|
||||
status.initialize_or_restore()
|
||||
# TODO(tanzheny): Add hyper variables to .variables(), and set them with
|
||||
# set_weights etc.
|
||||
variables_not_in_the_variables_property = [
|
||||
obj for obj in optimizer._hyper.values()
|
||||
if isinstance(obj, variables_lib.Variable)]
|
||||
self.evaluate([v.initializer for v
|
||||
in optimizer.variables()
|
||||
+ variables_not_in_the_variables_property])
|
||||
train_fn()
|
||||
model_save_path = root.save(file_prefix=checkpoint_prefix)
|
||||
self.evaluate(optimizer.beta_1.assign(42.))
|
||||
optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)
|
||||
del train_fn
|
||||
|
||||
# Restore into a graph with the optimizer
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
status = root.restore(save_path=model_save_path)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
def train_fn1():
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn1 = functools.partial(self.evaluate, train_fn1())
|
||||
status.initialize_or_restore()
|
||||
train_fn1()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_consumed()
|
||||
del train_fn1
|
||||
|
||||
# Make sure initialization doesn't clobber later restores
|
||||
with test_util.device(use_gpu=True):
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001, beta_1=1.0)
|
||||
root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
opt_root = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer)
|
||||
status = root.restore(save_path=model_save_path)
|
||||
init_only_optimizer_status = opt_root.restore(save_path=None)
|
||||
optimizer_status = opt_root.restore(save_path=optimizer_save_path)
|
||||
input_value = constant_op.constant([[3.]])
|
||||
def train_fn2():
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
return optimizer.apply_gradients(zip(gradients, variables))
|
||||
if not context.executing_eagerly():
|
||||
train_fn2 = functools.partial(self.evaluate, train_fn2())
|
||||
optimizer_status.run_restore_ops()
|
||||
status.initialize_or_restore()
|
||||
init_only_optimizer_status.initialize_or_restore()
|
||||
train_fn2()
|
||||
self.assertEqual(42., self.evaluate(optimizer.beta_1))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_restore_after_adding_empty_trackable_data_structure(self):
|
||||
model = NonLayerTrackable()
|
||||
@ -1439,75 +796,8 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
|
||||
self.assertAllClose(self.evaluate(load_checkpoint.b), {"a": 2, "b": 3})
|
||||
|
||||
|
||||
class _ManualScope(tracking.AutoTrackable):
|
||||
|
||||
def __call__(self):
|
||||
with variable_scope.variable_scope("ManualScope") as vs:
|
||||
self.variable_scope = vs
|
||||
with trackable_utils.capture_dependencies(template=self):
|
||||
return self._build()
|
||||
|
||||
def _build(self):
|
||||
return variable_scope.get_variable(name="in_manual_scope", shape=[])
|
||||
|
||||
|
||||
class TemplateTests(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_trackable_save_restore(self):
|
||||
|
||||
def _templated():
|
||||
v = variable_scope.get_variable(
|
||||
"v", shape=[1], initializer=init_ops.zeros_initializer(),
|
||||
use_resource=True)
|
||||
v2 = variable_scope.get_variable(
|
||||
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
|
||||
use_resource=True)
|
||||
manual = _ManualScope()
|
||||
return v, v + 1., v2, manual, manual()
|
||||
|
||||
save_template = template.make_template("s1", _templated)
|
||||
v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
|
||||
six.assertCountEqual(
|
||||
self,
|
||||
[id(v1_save), id(v2_save), id(manual_scope),
|
||||
id(manual_scope_v), id(save_template)],
|
||||
map(id, trackable_utils.list_objects(save_template)))
|
||||
manual_dep, = manual_scope._checkpoint_dependencies
|
||||
self.assertEqual("in_manual_scope", manual_dep.name)
|
||||
self.assertIs(manual_scope_v, manual_dep.ref)
|
||||
optimizer = adam.Adam(0.0)
|
||||
save_root = trackable_utils.Checkpoint(
|
||||
my_template=save_template, optimizer=optimizer)
|
||||
optimizer.minimize(v1_save.read_value,
|
||||
var_list=[v1_save])
|
||||
self.evaluate([v.initializer for v in save_template.variables])
|
||||
optimizer_variables = optimizer.variables() + list(
|
||||
optimizer._hyper.values())
|
||||
self.evaluate([v.initializer for v in optimizer_variables])
|
||||
self.evaluate(v1_save.assign([12.]))
|
||||
self.evaluate(v2_save.assign([14.]))
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
save_path = save_root.save(checkpoint_prefix)
|
||||
|
||||
load_template = template.make_template("s2", _templated)
|
||||
load_optimizer = adam.Adam(0.0)
|
||||
load_root = trackable_utils.Checkpoint(
|
||||
my_template=load_template, optimizer=load_optimizer)
|
||||
status = load_root.restore(save_path)
|
||||
var, var_plus_one, var2, _, _ = load_template()
|
||||
load_optimizer.minimize(var.read_value, var_list=[var])
|
||||
self.assertLen(load_template._checkpoint_dependencies, 3)
|
||||
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
|
||||
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
|
||||
self.assertEqual("ManualScope",
|
||||
load_template._checkpoint_dependencies[2].name)
|
||||
status.assert_consumed().run_restore_ops()
|
||||
self.assertAllEqual([12.], self.evaluate(var))
|
||||
self.assertAllEqual([13.], self.evaluate(var_plus_one))
|
||||
self.assertAllEqual([14.], self.evaluate(var2))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_trackable_save_restore_nested(self):
|
||||
|
||||
@ -1554,157 +844,6 @@ class TemplateTests(parameterized.TestCase, test.TestCase):
|
||||
self.assertAllEqual([25.], self.evaluate(v3))
|
||||
|
||||
|
||||
class CheckpointCompatibilityTests(test.TestCase):
|
||||
|
||||
def _initialized_model(self):
|
||||
input_value = constant_op.constant([[3.]])
|
||||
model = MyModel()
|
||||
optimizer = adam.Adam(0.001)
|
||||
root_trackable = trackable_utils.Checkpoint(
|
||||
optimizer=optimizer, model=model)
|
||||
with backprop.GradientTape() as tape:
|
||||
loss = model(input_value)
|
||||
variables = model.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
train_op = optimizer.apply_gradients(zip(gradients, variables))
|
||||
self.evaluate(trackable_utils.gather_initializers(
|
||||
root_trackable))
|
||||
self.evaluate(train_op)
|
||||
# A regular variable, a slot variable, and a non-slot Optimizer variable
|
||||
# with known values to check when loading.
|
||||
self.evaluate(model._named_dense.bias.assign([1.]))
|
||||
self.evaluate(optimizer.get_slot(
|
||||
var=model._named_dense.bias, slot_name="m").assign([2.]))
|
||||
self.evaluate(optimizer.beta_1.assign(3.))
|
||||
return root_trackable
|
||||
|
||||
def _set_sentinels(self, root_trackable):
|
||||
self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
|
||||
self.evaluate(
|
||||
root_trackable.optimizer.get_slot(
|
||||
var=root_trackable.model._named_dense.bias, slot_name="m")
|
||||
.assign([102.]))
|
||||
self.evaluate(root_trackable.optimizer.beta_1.assign(103.))
|
||||
|
||||
def _check_sentinels(self, root_trackable):
|
||||
self.assertAllEqual(
|
||||
[1.], self.evaluate(root_trackable.model._named_dense.bias))
|
||||
self.assertAllEqual([2.], self.evaluate(
|
||||
root_trackable.optimizer.get_slot(
|
||||
var=root_trackable.model._named_dense.bias, slot_name="m")))
|
||||
self.assertAllEqual(3.,
|
||||
self.evaluate(root_trackable.optimizer.beta_1))
|
||||
|
||||
def _write_name_based_checkpoint(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with context.graph_mode():
|
||||
save_graph = ops.Graph()
|
||||
with save_graph.as_default(), self.session(
|
||||
graph=save_graph) as session:
|
||||
root = self._initialized_model()
|
||||
name_saver = saver_lib.Saver()
|
||||
return name_saver.save(
|
||||
sess=session,
|
||||
save_path=checkpoint_prefix,
|
||||
global_step=root.optimizer.iterations)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testLoadFromNameBasedSaver(self):
|
||||
"""Save a name-based checkpoint, load it using the object-based API."""
|
||||
with test_util.device(use_gpu=True):
|
||||
save_path = self._write_name_based_checkpoint()
|
||||
root = self._initialized_model()
|
||||
self._set_sentinels(root)
|
||||
with self.assertRaises(AssertionError):
|
||||
self._check_sentinels(root)
|
||||
object_saver = trackable_utils.TrackableSaver(
|
||||
graph_view.ObjectGraphView(root))
|
||||
self._set_sentinels(root)
|
||||
status = object_saver.restore(save_path)
|
||||
if context.executing_eagerly():
|
||||
self._check_sentinels(root)
|
||||
if context.executing_eagerly():
|
||||
status.assert_consumed()
|
||||
status.assert_existing_objects_matched()
|
||||
status.assert_nontrivial_match()
|
||||
else:
|
||||
# When graph building, we haven't read any keys, so we don't know
|
||||
# whether the restore will be complete.
|
||||
with self.assertRaisesRegexp(AssertionError, "not restored"):
|
||||
status.assert_consumed()
|
||||
with self.assertRaisesRegexp(AssertionError, "not restored"):
|
||||
status.assert_existing_objects_matched()
|
||||
with self.assertRaisesRegexp(AssertionError, "not restored"):
|
||||
status.assert_nontrivial_match()
|
||||
status.run_restore_ops()
|
||||
self._check_sentinels(root)
|
||||
self._set_sentinels(root)
|
||||
status = object_saver.restore(save_path)
|
||||
status.initialize_or_restore()
|
||||
status.assert_nontrivial_match()
|
||||
self._check_sentinels(root)
|
||||
# Check that there is no error when keys are missing from the name-based
|
||||
# checkpoint.
|
||||
root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable([1.])
|
||||
status = object_saver.restore(save_path)
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
|
||||
def testSaveGraphLoadEager(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with context.graph_mode():
|
||||
save_graph = ops.Graph()
|
||||
with save_graph.as_default(), self.session(
|
||||
graph=save_graph):
|
||||
root = self._initialized_model()
|
||||
save_path = root.save(file_prefix=checkpoint_prefix)
|
||||
with context.eager_mode():
|
||||
root = self._initialized_model()
|
||||
self._set_sentinels(root)
|
||||
root.restore(save_path).assert_consumed()
|
||||
self._check_sentinels(root)
|
||||
|
||||
def testSaveEagerLoadGraph(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with context.eager_mode():
|
||||
root = self._initialized_model()
|
||||
save_path = root.save(file_prefix=checkpoint_prefix)
|
||||
with context.graph_mode():
|
||||
save_graph = ops.Graph()
|
||||
with save_graph.as_default(), self.session(
|
||||
graph=save_graph):
|
||||
root = self._initialized_model()
|
||||
self._set_sentinels(root)
|
||||
root.restore(save_path).assert_consumed().run_restore_ops()
|
||||
self._check_sentinels(root)
|
||||
|
||||
def testIgnoreSaveCounter(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
with self.cached_session() as session:
|
||||
# Create and save a model using Saver() before using a Checkpoint. This
|
||||
# generates a snapshot without the Checkpoint's `save_counter`.
|
||||
model = sequential.Sequential()
|
||||
model.add(core.Flatten(input_shape=(1,)))
|
||||
model.add(core.Dense(1))
|
||||
name_saver = saver_lib.Saver(model.trainable_variables)
|
||||
save_path = name_saver.save(
|
||||
sess=session, save_path=checkpoint_prefix, global_step=1)
|
||||
# Checkpoint.restore must successfully load that checkpoint.
|
||||
ckpt = trackable_utils.Checkpoint(model=model)
|
||||
status = ckpt.restore(save_path)
|
||||
status.assert_existing_objects_matched()
|
||||
# It should, however, refuse to load a checkpoint where an unrelated
|
||||
# `save_counter` variable is missing.
|
||||
model.layers[1].var = variables_lib.Variable(0., name="save_counter")
|
||||
status = ckpt.restore(save_path)
|
||||
with self.assertRaises(AssertionError):
|
||||
status.assert_existing_objects_matched()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user