Remove Autotrackable
from the Keras test files and replacing it with Module
.
PiperOrigin-RevId: 321296275 Change-Id: Idba3ced8d12b163daa254f04535a4542a1d9a45b
This commit is contained in:
parent
41339588d9
commit
83a712dc53
@ -29,12 +29,12 @@ from tensorflow.python.framework import convert_to_constants
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model.load import load
|
||||
from tensorflow.python.saved_model.save import save
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ class VariablesToConstantsTest(test.TestCase):
|
||||
root: AutoTrackable object with original ConcreteFunction.
|
||||
output_func: frozen ConcreteFunction.
|
||||
"""
|
||||
root = tracking.AutoTrackable()
|
||||
root = module.Module()
|
||||
root.f = model
|
||||
input_func = root.f.get_concrete_function()
|
||||
|
||||
@ -91,7 +91,7 @@ class VariablesToConstantsTest(test.TestCase):
|
||||
|
||||
# Save the converted ConcreteFunction as a signature.
|
||||
save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
|
||||
root = tracking.AutoTrackable()
|
||||
root = module.Module()
|
||||
root.f = converted_concrete_func
|
||||
save(root, save_dir, {"mykey": converted_concrete_func})
|
||||
|
||||
|
@ -27,16 +27,16 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import saver as saver_module
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.tracking import tracking as trackable_tracking
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
|
||||
class NonLayerTrackable(trackable_tracking.AutoTrackable):
|
||||
class NonLayerTrackable(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(NonLayerTrackable, self).__init__()
|
||||
|
@ -39,7 +39,6 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util
|
||||
|
||||
|
||||
@ -290,7 +289,7 @@ class MappingTests(keras_parameterized.TestCase):
|
||||
|
||||
def testLayerCollectionWithExternalMutation(self):
|
||||
d = {}
|
||||
root = tracking.AutoTrackable()
|
||||
root = module.Module()
|
||||
root.wrapper = d
|
||||
self.assertEqual([], root.wrapper.layers)
|
||||
self.assertEqual([], root.wrapper.trainable_weights)
|
||||
@ -303,7 +302,7 @@ class MappingTests(keras_parameterized.TestCase):
|
||||
self.assertEqual([], root.wrapper.trainable_weights)
|
||||
|
||||
def testDictWrapperBadKeys(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a = module.Module()
|
||||
a.d = {}
|
||||
a.d[1] = data_structures.List()
|
||||
model = training.Model()
|
||||
@ -313,7 +312,7 @@ class MappingTests(keras_parameterized.TestCase):
|
||||
model.save_weights(save_path)
|
||||
|
||||
def testDictWrapperNoDependency(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a = module.Module()
|
||||
a.d = data_structures.NoDependency({})
|
||||
a.d[1] = [3]
|
||||
self.assertEqual([a], util.list_objects(a))
|
||||
@ -324,7 +323,7 @@ class MappingTests(keras_parameterized.TestCase):
|
||||
model.load_weights(save_path)
|
||||
|
||||
def testNonStringKeyNotTrackableValue(self):
|
||||
a = tracking.AutoTrackable()
|
||||
a = module.Module()
|
||||
a.d = {}
|
||||
a.d["a"] = [3]
|
||||
a.d[1] = data_structures.NoDependency([3])
|
||||
@ -338,15 +337,15 @@ class MappingTests(keras_parameterized.TestCase):
|
||||
def testNonAppendNotTrackable(self):
|
||||
# Non-append mutations (deleting or overwriting values) are OK when the
|
||||
# values aren't tracked.
|
||||
a = tracking.AutoTrackable()
|
||||
a = module.Module()
|
||||
a.d = {}
|
||||
a.d["a"] = [3]
|
||||
a.d[1] = 3
|
||||
a.d[1] = 2
|
||||
self.assertEqual(2, a.d[1])
|
||||
del a.d[1]
|
||||
a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
|
||||
second = tracking.AutoTrackable()
|
||||
a.d[2] = data_structures.NoDependency(module.Module())
|
||||
second = module.Module()
|
||||
a.d[2] = data_structures.NoDependency(second)
|
||||
self.assertIs(second, a.d[2])
|
||||
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
|
||||
@ -550,10 +549,10 @@ class TupleTests(keras_parameterized.TestCase):
|
||||
class InterfaceTests(keras_parameterized.TestCase):
|
||||
|
||||
def testNoDependency(self):
|
||||
root = tracking.AutoTrackable()
|
||||
hasdep = tracking.AutoTrackable()
|
||||
root = module.Module()
|
||||
hasdep = module.Module()
|
||||
root.hasdep = hasdep
|
||||
nodep = tracking.AutoTrackable()
|
||||
nodep = module.Module()
|
||||
root.nodep = data_structures.NoDependency(nodep)
|
||||
self.assertEqual(1, len(root._checkpoint_dependencies))
|
||||
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
|
||||
@ -566,7 +565,7 @@ class InterfaceTests(keras_parameterized.TestCase):
|
||||
def __init__(self):
|
||||
super(NoDependencyModel, self).__init__()
|
||||
self.a = []
|
||||
self.b = tracking.AutoTrackable()
|
||||
self.b = module.Module()
|
||||
|
||||
nodeps = NoDependencyModel()
|
||||
self.assertEqual([nodeps], util.list_objects(nodeps))
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -48,7 +49,6 @@ from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
|
||||
@ -68,7 +68,7 @@ class MyModel(training.Model):
|
||||
return ret
|
||||
|
||||
|
||||
class NonLayerTrackable(tracking.AutoTrackable):
|
||||
class NonLayerTrackable(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(NonLayerTrackable, self).__init__()
|
||||
@ -709,7 +709,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
|
||||
self.assertEqual(42., self.evaluate(optimizer.beta_1))
|
||||
|
||||
|
||||
class _ManualScope(tracking.AutoTrackable):
|
||||
class _ManualScope(module.Module):
|
||||
|
||||
def __call__(self):
|
||||
with variable_scope.variable_scope("ManualScope") as vs:
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import adam
|
||||
@ -42,11 +43,10 @@ from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
|
||||
class NonLayerTrackable(tracking.AutoTrackable):
|
||||
class NonLayerTrackable(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(NonLayerTrackable, self).__init__()
|
||||
@ -460,7 +460,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
def _get_checkpoint_name(self, name):
|
||||
root = tracking.AutoTrackable()
|
||||
root = module.Module()
|
||||
trackable_utils.add_variable(
|
||||
root, name=name, shape=[1, 2], dtype=dtypes.float64)
|
||||
(named_variable,), _, _ = trackable_utils._serialize_object_graph(
|
||||
|
@ -23,13 +23,13 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
|
||||
class NonLayerTrackable(tracking.AutoTrackable):
|
||||
class NonLayerTrackable(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(NonLayerTrackable, self).__init__()
|
||||
|
Loading…
x
Reference in New Issue
Block a user