# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
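"""Tests for tracking of list, dict, and tuple attributes."""
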
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy
import six

from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import normalization
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util


class HasList(training.Model):

  def __init__(self):
    super(HasList, self).__init__()
    self.layer_list = data_structures.List([core.Dense(3)])
    self.layer_list.append(core.Dense(4))
    self.layer_list.extend(
        [core.Dense(5),
         core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
    self.layer_list += [
        core.Dense(7, bias_regularizer=math_ops.reduce_sum),
        core.Dense(8)
    ]
    self.layer_list += (
        data_structures.List([core.Dense(9)]) + data_structures.List(
            [core.Dense(10)]))
    self.layer_list.extend(
        data_structures.List(
            list([core.Dense(11)]) + [core.Dense(12)]))
    self.layers_with_updates = data_structures.List(
        (normalization.BatchNormalization(),))

  def call(self, x):
    aggregation = 0.
    for l in self.layer_list:
      x = l(x)
      aggregation += math_ops.reduce_sum(x)
    bn, = self.layers_with_updates
    return bn(x) / aggregation

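
# A minimal, illustrative sketch of the behavior exercised by ListTests below.
# The helper name is introduced here purely for illustration and is never
# called: assigning a plain Python list to a Model attribute wraps it in a
# trackable ListWrapper, so the layers it holds are tracked and checkpointed
# with the model.
def _list_tracking_sketch():
  model = training.Model()
  model.blocks = [core.Dense(3), core.Dense(4)]  # auto-wrapped for tracking
  # Expected to be True: the attribute still behaves like a list, but it also
  # participates in object-based checkpointing.
  return isinstance(model.blocks, data_structures.ListWrapper)

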
class ListTests(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testTracking(self):
    model = HasList()
    output = model(array_ops.ones([32, 2]))
    self.assertAllEqual([32, 12], output.shape)
    self.assertEqual(11, len(model.layers))
    self.assertEqual(10, len(model.layer_list.layers))
    six.assertCountEqual(
        self,
        model.layers,
        model.layer_list.layers + model.layers_with_updates)
    for index in range(10):
      self.assertEqual(3 + index, model.layer_list.layers[index].units)
    self.assertEqual(2, len(model._checkpoint_dependencies))
    self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
    self.assertIs(model.layers_with_updates,
                  model._checkpoint_dependencies[1].ref)
    self.assertEqual(
        10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies))
    self.evaluate([v.initializer for v in model.variables])
    self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)
    self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
    model.load_weights(save_path)
    self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
                        self.evaluate(model.variables[0]))
    v = variables.Variable(1.)
    model.var_list = [v]
    self.assertIn(v, model.variables)
    self.assertIn(v, model.trainable_variables)
    self.assertNotIn(v, model.non_trainable_variables)
    self.assertIn(model.layer_list[0].trainable_weights[0],
                  model.trainable_weights)

  def testSubModelTracking(self):
    model = training.Model()
    model.v = variables.Variable(1.)
    self.assertIn(model.v, model.trainable_weights)
    model2 = training.Model()
    model2.m = [model]
    self.assertIn(model.v, model2.trainable_weights)

  def testSubSequentialTracking(self):

    class _Subclassed(training.Model):

      def __init__(self, wrapped):
        super(_Subclassed, self).__init__()
        self._wrapped = wrapped

      def call(self, x):
        return self._wrapped(x)

    model = sequential.Sequential()
    layer = core.Dense(1)
    model.add(layer)
    model2 = _Subclassed(model)
    model2(array_ops.ones([1, 2]))
    model2.m = [model]
    self.assertIn(layer.kernel, model2.trainable_weights)

  def testLayerTrackedThroughSequential(self):
    class AttrDict(dict):

      def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

    def ffnet(layer_sizes, name):
      ff = sequential.Sequential(name=name)
      for i, width in enumerate(layer_sizes):
        ff.add(core.Dense(
            width,
            activation=("relu" if i < len(layer_sizes) - 1 else None)))
      return ff

    class MyModel2(training.Model):

      def __init__(self, config, name="my_model_2"):
        super(MyModel2, self).__init__(name=name)
        self._num_tokens = config.num_tokens

        # List of sub-models.
        self._ffnet = [ffnet(config.module_layers + (self._num_tokens,), "ff")]

      def null_input(self):
        return array_ops.zeros([1, self._num_tokens], dtype=dtypes.float32)

      def call(self, input_, module_index=None):
        return self._ffnet[0](input_)

    m2 = MyModel2(AttrDict(
        num_tokens=5,
        module_layers=(50, 30)))

    # Construct the model (and create its variables) by calling it once.
    m2(m2.null_input())
    self.assertLen(m2.trainable_variables, 6)

  @test_util.run_v1_only("b/120545219")
  def testUpdatesForwarded(self):
    with context.graph_mode():
      model = HasList()
      model_input = array_ops.ones([32, 2])
      model(model_input)
      self.assertGreater(len(model.layers_with_updates[0].updates), 0)
      self.assertEqual(set(model.layers_with_updates[0].updates),
                       set(model.updates))

    with context.eager_mode():
      model = HasList()
      model_input = array_ops.ones([32, 2])
      model(model_input)
      self.assertEqual(0, len(model.updates))

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testLossesForwarded(self):
    model = HasList()
    model_input = array_ops.ones([32, 2])
    model(model_input)
    self.assertEqual(2, len(model.losses))

  def testModelContainersCompareEqual(self):
    class HasEqualContainers(training.Model):

      def __init__(self):
        super(HasEqualContainers, self).__init__()
        self.l1 = []
        self.l2 = []

    model = HasEqualContainers()
    first_layer = HasEqualContainers()
    model.l1.append(first_layer)
    second_layer = HasEqualContainers()
    model.l2.append(second_layer)
    self.assertEqual([first_layer, second_layer], model.layers)

  @test_util.run_in_graph_and_eager_modes
  def testTensorConversion(self):

    class ListToTensor(training.Model):

      def __init__(self):
        super(ListToTensor, self).__init__()
        self.l = [1., 2., 3.]

    self.assertAllEqual(
        [1., 2., 3.],
        self.evaluate(constant_op.constant(ListToTensor().l)))

    self.assertAllEqual(
        [1., 2., 3.],
        self.evaluate(array_ops.stack(ListToTensor().l)))


class ListWrapperTest(test.TestCase):

  def testLayerCollectionWithExternalMutation(self):
    l = []
    l_wrapper = data_structures.ListWrapper(l)
    layer = core.Dense(1)
    l.append(layer)
    self.assertEqual([layer], l_wrapper.layers)


class HasMapping(training.Model):

  def __init__(self):
    super(HasMapping, self).__init__()
    self.layer_dict = data_structures.Mapping(output=core.Dense(7))
    self.layer_dict["norm"] = data_structures.List()
    self.layer_dict["dense"] = data_structures.List()
    self.layer_dict["dense"].extend(
        [core.Dense(5),
         core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
    self.layer_dict["norm"].append(
        normalization.BatchNormalization())
    self.layer_dict["norm"].append(
        normalization.BatchNormalization())

  def call(self, x):
    aggregation = 0.
    for norm, dense in zip(self.layer_dict["norm"], self.layer_dict["dense"]):
      x = norm(dense(x))
      aggregation += math_ops.reduce_sum(x)
    return self.layer_dict["output"](x) / aggregation

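
# A minimal, illustrative sketch of the dict behaviors exercised by
# MappingTests below. The helper name is introduced here purely for
# illustration and is never called: a plain dict attached to a trackable
# object is wrapped so its values are tracked, while
# data_structures.NoDependency opts a value out of dependency tracking.
def _mapping_tracking_sketch():
  root = tracking.AutoTrackable()
  root.tracked = {"dense": core.Dense(1)}  # wrapped; "dense" is a dependency
  root.untracked = data_structures.NoDependency({1: [3.]})  # no dependency
  # Expected to contain root, the wrapped dict, and the Dense layer, but
  # nothing reachable only through the NoDependency dict.
  return util.list_objects(root)

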
class MappingTests(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testTracking(self):
    model = HasMapping()
    output = model(array_ops.ones([32, 2]))
    self.assertAllEqual([32, 7], output.shape.as_list())
    self.assertEqual(5, len(model.layers))
    six.assertCountEqual(self, model.layers, model.layer_dict.layers)
    self.assertEqual(1, len(model._checkpoint_dependencies))
    self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref)
    self.evaluate([v.initializer for v in model.variables])
    test_var = model.layer_dict["output"].kernel
    self.evaluate(test_var.assign(array_ops.ones([6, 7])))
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)
    self.evaluate(test_var.assign(array_ops.zeros([6, 7])))
    model.load_weights(save_path)
    self.assertAllEqual(numpy.ones([6, 7]),
                        self.evaluate(test_var))

  def testLayerCollectionWithExternalMutation(self):
    d = {}
    root = tracking.AutoTrackable()
    root.wrapper = d
    self.assertEqual([], root.wrapper.layers)
    self.assertEqual([], root.wrapper.trainable_weights)
    layer1 = core.Dense(1)
    layer2 = core.Dense(1)
    d["a"] = layer1
    d["b"] = layer2
    self.assertEqual([layer1, layer2], root.wrapper.layers)
    # The layers have not created their variables yet, so nothing is trainable.
    self.assertEqual([], root.wrapper.trainable_weights)

  def testDictWrapperBadKeys(self):
    a = tracking.AutoTrackable()
    a.d = {}
    a.d[1] = data_structures.List()
    model = training.Model()
    model.sub = a
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "non-string key"):
      model.save_weights(save_path)

  def testDictWrapperNoDependency(self):
    a = tracking.AutoTrackable()
    a.d = data_structures.NoDependency({})
    a.d[1] = [3]
    self.assertEqual([a], util.list_objects(a))
    model = training.Model()
    model.sub = a
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)
    model.load_weights(save_path)

  def testNonStringKeyNotTrackableValue(self):
    a = tracking.AutoTrackable()
    a.d = {}
    a.d["a"] = [3]
    a.d[1] = data_structures.NoDependency([3])
    self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
    model = training.Model()
    model.sub = a
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)
    model.load_weights(save_path)

  def testNonAppendNotTrackable(self):
    # Non-append mutations (deleting or overwriting values) are OK when the
    # values aren't tracked.
    a = tracking.AutoTrackable()
    a.d = {}
    a.d["a"] = [3]
    a.d[1] = 3
    a.d[1] = 2
    self.assertEqual(2, a.d[1])
    del a.d[1]
    a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
    second = tracking.AutoTrackable()
    a.d[2] = data_structures.NoDependency(second)
    self.assertIs(second, a.d[2])
    self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
    model = training.Model()
    model.sub = a
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)
    model.load_weights(save_path)

  def testPopNoSave(self):
    model = training.Model()
    model.d = {}
    model.d["a"] = []
    model.d.pop("a")
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "Unable to save"):
      model.save_weights(save_path)

  def testExternalModificationNoSave(self):
    model = training.Model()
    external_reference = {}
    model.d = external_reference
    external_reference["a"] = []
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"):
      model.save_weights(save_path)

  def testOverwriteCanStillSave(self):
    model = training.Model()
    model.d = {}
    model.d["a"] = {}
    model.d["a"] = {}
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)

  def testIter(self):
    model = training.Model()
    model.d = {1: 3}
    model.d[1] = 3
    self.assertEqual([1], list(model.d))
    new_dict = {}
    # This update() is tricky. If the dict wrapper subclasses dict, CPython
    # will access its storage directly instead of calling any
    # methods/properties on the object. So the options are either not to
    # subclass dict (in which case update() will call the normal iteration
    # methods, but the object won't pass isinstance checks) or to subclass
    # dict and keep that storage updated (without shadowing all of its methods
    # the way ListWrapper does).
    new_dict.update(model.d)
    self.assertEqual({1: 3}, new_dict)


class HasTuple(training.Model):

  def __init__(self):
    super(HasTuple, self).__init__()
    self.layer_list = (
        core.Dense(3), core.Dense(4),
        core.Dense(5, kernel_regularizer=math_ops.reduce_sum))
    self.layers_with_updates = (normalization.BatchNormalization(),)

  def call(self, x):
    aggregation = 0.
    for l in self.layer_list:
      x = l(x)
      aggregation += math_ops.reduce_sum(x)
    bn, = self.layers_with_updates
    return bn(x) / aggregation

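
# A minimal, illustrative sketch of the behavior exercised by TupleTests below.
# The helper name is introduced here purely for illustration and is never
# called: tuples assigned as attributes are tracked much like lists, so
# variables reachable through them appear in trainable_variables.
def _tuple_tracking_sketch():
  m = module.Module()
  m.vs = (variables.Variable(1.), variables.Variable(2.))
  return len(m.trainable_variables)  # expected: 2

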
class TupleTests(test.TestCase, parameterized.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testTracking(self):
    model = HasTuple()
    output = model(array_ops.ones([32, 2]))
    self.assertAllEqual([32, 5], output.shape.as_list())
    self.assertLen(model.layers, 4)
    self.assertLen(model.layer_list.layers, 3)
    six.assertCountEqual(
        self,
        model.layers,
        tuple(model.layer_list.layers) + model.layers_with_updates)
    self.assertEqual(3, model.layer_list.layers[0].units)
    self.assertEqual(4, model.layer_list.layers[1].units)
    self.assertEqual(5, model.layer_list.layers[2].units)
    self.assertLen(model._checkpoint_dependencies, 2)
    self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
    self.assertIs(model.layers_with_updates,
                  model._checkpoint_dependencies[1].ref)
    self.assertLen(
        model._checkpoint_dependencies[0].ref._checkpoint_dependencies, 3)
    self.evaluate([v.initializer for v in model.variables])
    self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
    save_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(save_path)
    self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
    model.load_weights(save_path)
    self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
                        self.evaluate(model.variables[0]))
    v = variables.Variable(1.)
    model.var_list = (v,)
    self.assertIn(id(v), [id(obj) for obj in model.variables])
    self.assertIn(id(v), [id(obj) for obj in model.trainable_variables])
    self.assertNotIn(id(v), [id(obj) for obj in model.non_trainable_variables])
    self.assertIn(id(model.layer_list[0].trainable_weights[0]),
                  [id(obj) for obj in model.trainable_weights])

  @parameterized.named_parameters(
      ("Module", module.Module),
      ("Model", training.Model),
  )
  def testSubModelTracking(self, module_subclass):
    model = module_subclass()
    model.v = variables.Variable(1.)
    self.assertIn(model.v, model.trainable_variables)
    model2 = module_subclass()
    model2.m = (model,)
    self.assertIn(model.v, model2.trainable_variables)

  def testSubSequentialTracking(self):

    class _Subclassed(training.Model):

      def __init__(self, wrapped):
        super(_Subclassed, self).__init__()
        self._wrapped = wrapped

      def call(self, x):
        return self._wrapped(x)

    model = sequential.Sequential()
    layer = core.Dense(1)
    model.add(layer)
    model2 = _Subclassed(model)
    model2(array_ops.ones([1, 2]))
    model2.m = (model,)
    self.assertIn(layer.kernel, model2.trainable_weights)

  def testUpdatesForwarded(self):
    with ops.Graph().as_default():
      model = HasTuple()
      model_input = array_ops.ones([32, 2])
      model(model_input)
      self.assertNotEmpty(model.layers_with_updates[0].updates)
      self.assertEqual(set(model.layers_with_updates[0].updates),
                       set(model.updates))

    model = HasTuple()
    model_input = array_ops.ones([32, 2])
    model(model_input)
    self.assertEmpty(model.updates)

  @test_util.run_in_graph_and_eager_modes
  def testLossesForwarded(self):
    model = HasTuple()
    model_input = array_ops.ones([32, 2])
    model(model_input)
    self.assertLen(model.losses, 1)

  def testModelContainersCompareEqual(self):
    class HasEqualContainers(training.Model):

      def __init__(self):
        super(HasEqualContainers, self).__init__()
        self.l1 = ()
        self.l2 = ()

    model = HasEqualContainers()
    first_layer = HasEqualContainers()
    model.l1 = (first_layer,)
    second_layer = HasEqualContainers()
    model.l2 = (second_layer,)
    self.assertEqual((first_layer,), model.l1)
    d = {model.l1: 1, model.l2: 2}
    self.assertEqual(1, d[model.l1])
    self.assertEqual(1, d[(first_layer,)])
    self.assertEqual(2, d[model.l2])
    self.assertEqual(2, d[(second_layer,)])
    self.assertEqual([first_layer, second_layer], model.layers)

  @test_util.run_in_graph_and_eager_modes
  def testTensorConversion(self):

    class TupleToTensor(training.Model):

      def __init__(self):
        super(TupleToTensor, self).__init__()
        self.l = (1., 2., 3.)

    self.assertAllEqual(
        (1., 2., 3.),
        self.evaluate(constant_op.constant(TupleToTensor().l)))

    self.assertAllEqual(
        (1., 2., 3.),
        self.evaluate(array_ops.stack(TupleToTensor().l)))


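
# A minimal, illustrative sketch of the decorator exercised by InterfaceTests
# below. This class is introduced here purely for illustration and is never
# instantiated: base.no_automatic_dependency_tracking disables attribute-based
# dependency tracking inside the decorated method, so containers assigned
# there are left as plain Python objects and are not checkpointed.
class _NoTrackingSketch(tracking.AutoTrackable):

  @base.no_automatic_dependency_tracking
  def __init__(self):
    super(_NoTrackingSketch, self).__init__()
    self.buffer = []  # stays a plain list; not a checkpoint dependency

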
class InterfaceTests(test.TestCase):

  def testNoDependency(self):
    root = tracking.AutoTrackable()
    hasdep = tracking.AutoTrackable()
    root.hasdep = hasdep
    nodep = tracking.AutoTrackable()
    root.nodep = data_structures.NoDependency(nodep)
    self.assertEqual(1, len(root._checkpoint_dependencies))
    self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
    self.assertIs(root.hasdep, hasdep)
    self.assertIs(root.nodep, nodep)

    class NoDependencyModel(training.Model):

      @base.no_automatic_dependency_tracking
      def __init__(self):
        super(NoDependencyModel, self).__init__()
        self.a = []
        self.b = tracking.AutoTrackable()

    nodeps = NoDependencyModel()
    self.assertEqual([nodeps], util.list_objects(nodeps))

  @test_util.run_in_graph_and_eager_modes
  def testDictionariesBasic(self):
    a = training.Model()
    b = training.Model()
    a.attribute = {"b": b}
    c = training.Model()
    a.attribute["c"] = []
    a.attribute["c"].append(c)
    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertIs(b, a.attribute["b"])
    six.assertCountEqual(
        self,
        ["b", "c"],
        [dep.name for dep in a.attribute._checkpoint_dependencies])
    self.assertEqual([b, c], a.layers)
    self.assertEqual([b, c], a.attribute.layers)
    self.assertEqual([c], a.attribute["c"].layers)
    checkpoint = util.Checkpoint(a=a)
    save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    with self.cached_session():
      checkpoint.restore(save_path).assert_consumed().initialize_or_restore()

  @test_util.run_in_graph_and_eager_modes
  def testNoDepList(self):
    a = training.Model()
    a.l1 = data_structures.NoDependency([])
    a.l1.insert(1, 0)
    self.assertIsInstance(a.l1, list)
    checkpoint = util.Checkpoint(a=a)
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    a.l2 = []
    a.l2.insert(1, module.Module())
    with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
      checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))


if __name__ == "__main__":
  test.main()