Remove Autotrackable from the Keras test files and replacing it with Module.

PiperOrigin-RevId: 321296275
Change-Id: Idba3ced8d12b163daa254f04535a4542a1d9a45b
This commit is contained in:
Pavithra Vijay 2020-07-14 21:48:19 -07:00 committed by TensorFlower Gardener
parent 41339588d9
commit 83a712dc53
6 changed files with 24 additions and 25 deletions

View File

@ -29,12 +29,12 @@ from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
@ -50,7 +50,7 @@ class VariablesToConstantsTest(test.TestCase):
root: AutoTrackable object with original ConcreteFunction.
output_func: frozen ConcreteFunction.
"""
root = tracking.AutoTrackable()
root = module.Module()
root.f = model
input_func = root.f.get_concrete_function()
@ -91,7 +91,7 @@ class VariablesToConstantsTest(test.TestCase):
# Save the converted ConcreteFunction as a signature.
save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
root = tracking.AutoTrackable()
root = module.Module()
root.f = converted_concrete_func
save(root, save_dir, {"mykey": converted_concrete_func})

View File

@ -27,16 +27,16 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.module import module
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import training_util
from tensorflow.python.training.tracking import tracking as trackable_tracking
from tensorflow.python.training.tracking import util as trackable_utils
class NonLayerTrackable(trackable_tracking.AutoTrackable):
class NonLayerTrackable(module.Module):
def __init__(self):
super(NonLayerTrackable, self).__init__()

View File

@ -39,7 +39,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
@ -290,7 +289,7 @@ class MappingTests(keras_parameterized.TestCase):
def testLayerCollectionWithExternalMutation(self):
d = {}
root = tracking.AutoTrackable()
root = module.Module()
root.wrapper = d
self.assertEqual([], root.wrapper.layers)
self.assertEqual([], root.wrapper.trainable_weights)
@ -303,7 +302,7 @@ class MappingTests(keras_parameterized.TestCase):
self.assertEqual([], root.wrapper.trainable_weights)
def testDictWrapperBadKeys(self):
a = tracking.AutoTrackable()
a = module.Module()
a.d = {}
a.d[1] = data_structures.List()
model = training.Model()
@ -313,7 +312,7 @@ class MappingTests(keras_parameterized.TestCase):
model.save_weights(save_path)
def testDictWrapperNoDependency(self):
a = tracking.AutoTrackable()
a = module.Module()
a.d = data_structures.NoDependency({})
a.d[1] = [3]
self.assertEqual([a], util.list_objects(a))
@ -324,7 +323,7 @@ class MappingTests(keras_parameterized.TestCase):
model.load_weights(save_path)
def testNonStringKeyNotTrackableValue(self):
a = tracking.AutoTrackable()
a = module.Module()
a.d = {}
a.d["a"] = [3]
a.d[1] = data_structures.NoDependency([3])
@ -338,15 +337,15 @@ class MappingTests(keras_parameterized.TestCase):
def testNonAppendNotTrackable(self):
# Non-append mutations (deleting or overwriting values) are OK when the
# values aren't tracked.
a = tracking.AutoTrackable()
a = module.Module()
a.d = {}
a.d["a"] = [3]
a.d[1] = 3
a.d[1] = 2
self.assertEqual(2, a.d[1])
del a.d[1]
a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
second = tracking.AutoTrackable()
a.d[2] = data_structures.NoDependency(module.Module())
second = module.Module()
a.d[2] = data_structures.NoDependency(second)
self.assertIs(second, a.d[2])
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
@ -550,10 +549,10 @@ class TupleTests(keras_parameterized.TestCase):
class InterfaceTests(keras_parameterized.TestCase):
def testNoDependency(self):
root = tracking.AutoTrackable()
hasdep = tracking.AutoTrackable()
root = module.Module()
hasdep = module.Module()
root.hasdep = hasdep
nodep = tracking.AutoTrackable()
nodep = module.Module()
root.nodep = data_structures.NoDependency(nodep)
self.assertEqual(1, len(root._checkpoint_dependencies))
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
@ -566,7 +565,7 @@ class InterfaceTests(keras_parameterized.TestCase):
def __init__(self):
super(NoDependencyModel, self).__init__()
self.a = []
self.b = tracking.AutoTrackable()
self.b = module.Module()
nodeps = NoDependencyModel()
self.assertEqual([nodeps], util.list_objects(nodeps))

View File

@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.module import module
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
@ -48,7 +49,6 @@ from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
@ -68,7 +68,7 @@ class MyModel(training.Model):
return ret
class NonLayerTrackable(tracking.AutoTrackable):
class NonLayerTrackable(module.Module):
def __init__(self):
super(NonLayerTrackable, self).__init__()
@ -709,7 +709,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
self.assertEqual(42., self.evaluate(optimizer.beta_1))
class _ManualScope(tracking.AutoTrackable):
class _ManualScope(module.Module):
def __call__(self):
with variable_scope.variable_scope("ManualScope") as vs:

View File

@ -35,6 +35,7 @@ from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.module import module
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import adam
@ -42,11 +43,10 @@ from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
class NonLayerTrackable(tracking.AutoTrackable):
class NonLayerTrackable(module.Module):
def __init__(self):
super(NonLayerTrackable, self).__init__()
@ -460,7 +460,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
# pylint: enable=cell-var-from-loop
def _get_checkpoint_name(self, name):
root = tracking.AutoTrackable()
root = module.Module()
trackable_utils.add_variable(
root, name=name, shape=[1, 2], dtype=dtypes.float64)
(named_variable,), _, _ = trackable_utils._serialize_object_graph(

View File

@ -23,13 +23,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.module import module
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
class NonLayerTrackable(tracking.AutoTrackable):
class NonLayerTrackable(module.Module):
def __init__(self):
super(NonLayerTrackable, self).__init__()