Remove Autotrackable
from the Keras test files and replacing it with Module
.
PiperOrigin-RevId: 321296275 Change-Id: Idba3ced8d12b163daa254f04535a4542a1d9a45b
This commit is contained in:
parent
41339588d9
commit
83a712dc53
@ -29,12 +29,12 @@ from tensorflow.python.framework import convert_to_constants
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model.load import load
|
from tensorflow.python.saved_model.load import load
|
||||||
from tensorflow.python.saved_model.save import save
|
from tensorflow.python.saved_model.save import save
|
||||||
from tensorflow.python.training.tracking import tracking
|
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ class VariablesToConstantsTest(test.TestCase):
|
|||||||
root: AutoTrackable object with original ConcreteFunction.
|
root: AutoTrackable object with original ConcreteFunction.
|
||||||
output_func: frozen ConcreteFunction.
|
output_func: frozen ConcreteFunction.
|
||||||
"""
|
"""
|
||||||
root = tracking.AutoTrackable()
|
root = module.Module()
|
||||||
root.f = model
|
root.f = model
|
||||||
input_func = root.f.get_concrete_function()
|
input_func = root.f.get_concrete_function()
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ class VariablesToConstantsTest(test.TestCase):
|
|||||||
|
|
||||||
# Save the converted ConcreteFunction as a signature.
|
# Save the converted ConcreteFunction as a signature.
|
||||||
save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
|
save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model")
|
||||||
root = tracking.AutoTrackable()
|
root = module.Module()
|
||||||
root.f = converted_concrete_func
|
root.f = converted_concrete_func
|
||||||
save(root, save_dir, {"mykey": converted_concrete_func})
|
save(root, save_dir, {"mykey": converted_concrete_func})
|
||||||
|
|
||||||
|
@ -27,16 +27,16 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops as ops_lib
|
from tensorflow.python.framework import ops as ops_lib
|
||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import adam
|
from tensorflow.python.training import adam
|
||||||
from tensorflow.python.training import saver as saver_module
|
from tensorflow.python.training import saver as saver_module
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.training.tracking import tracking as trackable_tracking
|
|
||||||
from tensorflow.python.training.tracking import util as trackable_utils
|
from tensorflow.python.training.tracking import util as trackable_utils
|
||||||
|
|
||||||
|
|
||||||
class NonLayerTrackable(trackable_tracking.AutoTrackable):
|
class NonLayerTrackable(module.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(NonLayerTrackable, self).__init__()
|
super(NonLayerTrackable, self).__init__()
|
||||||
|
@ -39,7 +39,6 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
from tensorflow.python.training.tracking import tracking
|
|
||||||
from tensorflow.python.training.tracking import util
|
from tensorflow.python.training.tracking import util
|
||||||
|
|
||||||
|
|
||||||
@ -290,7 +289,7 @@ class MappingTests(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
def testLayerCollectionWithExternalMutation(self):
|
def testLayerCollectionWithExternalMutation(self):
|
||||||
d = {}
|
d = {}
|
||||||
root = tracking.AutoTrackable()
|
root = module.Module()
|
||||||
root.wrapper = d
|
root.wrapper = d
|
||||||
self.assertEqual([], root.wrapper.layers)
|
self.assertEqual([], root.wrapper.layers)
|
||||||
self.assertEqual([], root.wrapper.trainable_weights)
|
self.assertEqual([], root.wrapper.trainable_weights)
|
||||||
@ -303,7 +302,7 @@ class MappingTests(keras_parameterized.TestCase):
|
|||||||
self.assertEqual([], root.wrapper.trainable_weights)
|
self.assertEqual([], root.wrapper.trainable_weights)
|
||||||
|
|
||||||
def testDictWrapperBadKeys(self):
|
def testDictWrapperBadKeys(self):
|
||||||
a = tracking.AutoTrackable()
|
a = module.Module()
|
||||||
a.d = {}
|
a.d = {}
|
||||||
a.d[1] = data_structures.List()
|
a.d[1] = data_structures.List()
|
||||||
model = training.Model()
|
model = training.Model()
|
||||||
@ -313,7 +312,7 @@ class MappingTests(keras_parameterized.TestCase):
|
|||||||
model.save_weights(save_path)
|
model.save_weights(save_path)
|
||||||
|
|
||||||
def testDictWrapperNoDependency(self):
|
def testDictWrapperNoDependency(self):
|
||||||
a = tracking.AutoTrackable()
|
a = module.Module()
|
||||||
a.d = data_structures.NoDependency({})
|
a.d = data_structures.NoDependency({})
|
||||||
a.d[1] = [3]
|
a.d[1] = [3]
|
||||||
self.assertEqual([a], util.list_objects(a))
|
self.assertEqual([a], util.list_objects(a))
|
||||||
@ -324,7 +323,7 @@ class MappingTests(keras_parameterized.TestCase):
|
|||||||
model.load_weights(save_path)
|
model.load_weights(save_path)
|
||||||
|
|
||||||
def testNonStringKeyNotTrackableValue(self):
|
def testNonStringKeyNotTrackableValue(self):
|
||||||
a = tracking.AutoTrackable()
|
a = module.Module()
|
||||||
a.d = {}
|
a.d = {}
|
||||||
a.d["a"] = [3]
|
a.d["a"] = [3]
|
||||||
a.d[1] = data_structures.NoDependency([3])
|
a.d[1] = data_structures.NoDependency([3])
|
||||||
@ -338,15 +337,15 @@ class MappingTests(keras_parameterized.TestCase):
|
|||||||
def testNonAppendNotTrackable(self):
|
def testNonAppendNotTrackable(self):
|
||||||
# Non-append mutations (deleting or overwriting values) are OK when the
|
# Non-append mutations (deleting or overwriting values) are OK when the
|
||||||
# values aren't tracked.
|
# values aren't tracked.
|
||||||
a = tracking.AutoTrackable()
|
a = module.Module()
|
||||||
a.d = {}
|
a.d = {}
|
||||||
a.d["a"] = [3]
|
a.d["a"] = [3]
|
||||||
a.d[1] = 3
|
a.d[1] = 3
|
||||||
a.d[1] = 2
|
a.d[1] = 2
|
||||||
self.assertEqual(2, a.d[1])
|
self.assertEqual(2, a.d[1])
|
||||||
del a.d[1]
|
del a.d[1]
|
||||||
a.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
|
a.d[2] = data_structures.NoDependency(module.Module())
|
||||||
second = tracking.AutoTrackable()
|
second = module.Module()
|
||||||
a.d[2] = data_structures.NoDependency(second)
|
a.d[2] = data_structures.NoDependency(second)
|
||||||
self.assertIs(second, a.d[2])
|
self.assertIs(second, a.d[2])
|
||||||
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
|
self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a))
|
||||||
@ -550,10 +549,10 @@ class TupleTests(keras_parameterized.TestCase):
|
|||||||
class InterfaceTests(keras_parameterized.TestCase):
|
class InterfaceTests(keras_parameterized.TestCase):
|
||||||
|
|
||||||
def testNoDependency(self):
|
def testNoDependency(self):
|
||||||
root = tracking.AutoTrackable()
|
root = module.Module()
|
||||||
hasdep = tracking.AutoTrackable()
|
hasdep = module.Module()
|
||||||
root.hasdep = hasdep
|
root.hasdep = hasdep
|
||||||
nodep = tracking.AutoTrackable()
|
nodep = module.Module()
|
||||||
root.nodep = data_structures.NoDependency(nodep)
|
root.nodep = data_structures.NoDependency(nodep)
|
||||||
self.assertEqual(1, len(root._checkpoint_dependencies))
|
self.assertEqual(1, len(root._checkpoint_dependencies))
|
||||||
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
|
self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
|
||||||
@ -566,7 +565,7 @@ class InterfaceTests(keras_parameterized.TestCase):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(NoDependencyModel, self).__init__()
|
super(NoDependencyModel, self).__init__()
|
||||||
self.a = []
|
self.a = []
|
||||||
self.b = tracking.AutoTrackable()
|
self.b = module.Module()
|
||||||
|
|
||||||
nodeps = NoDependencyModel()
|
nodeps = NoDependencyModel()
|
||||||
self.assertEqual([nodeps], util.list_objects(nodeps))
|
self.assertEqual([nodeps], util.list_objects(nodeps))
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import sequential
|
|||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.keras.optimizer_v2 import adam
|
from tensorflow.python.keras.optimizer_v2 import adam
|
||||||
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
@ -48,7 +49,6 @@ from tensorflow.python.training import checkpoint_management
|
|||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.training.tracking import graph_view
|
from tensorflow.python.training.tracking import graph_view
|
||||||
from tensorflow.python.training.tracking import tracking
|
|
||||||
from tensorflow.python.training.tracking import util as trackable_utils
|
from tensorflow.python.training.tracking import util as trackable_utils
|
||||||
|
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ class MyModel(training.Model):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class NonLayerTrackable(tracking.AutoTrackable):
|
class NonLayerTrackable(module.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(NonLayerTrackable, self).__init__()
|
super(NonLayerTrackable, self).__init__()
|
||||||
@ -709,7 +709,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(42., self.evaluate(optimizer.beta_1))
|
self.assertEqual(42., self.evaluate(optimizer.beta_1))
|
||||||
|
|
||||||
|
|
||||||
class _ManualScope(tracking.AutoTrackable):
|
class _ManualScope(module.Module):
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
with variable_scope.variable_scope("ManualScope") as vs:
|
with variable_scope.variable_scope("ManualScope") as vs:
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.python.keras import combinations
|
|||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.training import adam
|
from tensorflow.python.training import adam
|
||||||
@ -42,11 +43,10 @@ from tensorflow.python.training import checkpoint_management
|
|||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.training.tracking import graph_view
|
from tensorflow.python.training.tracking import graph_view
|
||||||
from tensorflow.python.training.tracking import tracking
|
|
||||||
from tensorflow.python.training.tracking import util as trackable_utils
|
from tensorflow.python.training.tracking import util as trackable_utils
|
||||||
|
|
||||||
|
|
||||||
class NonLayerTrackable(tracking.AutoTrackable):
|
class NonLayerTrackable(module.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(NonLayerTrackable, self).__init__()
|
super(NonLayerTrackable, self).__init__()
|
||||||
@ -460,7 +460,7 @@ class CheckpointingTests(keras_parameterized.TestCase):
|
|||||||
# pylint: enable=cell-var-from-loop
|
# pylint: enable=cell-var-from-loop
|
||||||
|
|
||||||
def _get_checkpoint_name(self, name):
|
def _get_checkpoint_name(self, name):
|
||||||
root = tracking.AutoTrackable()
|
root = module.Module()
|
||||||
trackable_utils.add_variable(
|
trackable_utils.add_variable(
|
||||||
root, name=name, shape=[1, 2], dtype=dtypes.float64)
|
root, name=name, shape=[1, 2], dtype=dtypes.float64)
|
||||||
(named_variable,), _, _ = trackable_utils._serialize_object_graph(
|
(named_variable,), _, _ = trackable_utils._serialize_object_graph(
|
||||||
|
@ -23,13 +23,13 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.keras.engine import training
|
from tensorflow.python.keras.engine import training
|
||||||
from tensorflow.python.keras.layers import core
|
from tensorflow.python.keras.layers import core
|
||||||
from tensorflow.python.keras.optimizer_v2 import adam
|
from tensorflow.python.keras.optimizer_v2 import adam
|
||||||
|
from tensorflow.python.module import module
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training.tracking import tracking
|
|
||||||
from tensorflow.python.training.tracking import util as trackable_utils
|
from tensorflow.python.training.tracking import util as trackable_utils
|
||||||
|
|
||||||
|
|
||||||
class NonLayerTrackable(tracking.AutoTrackable):
|
class NonLayerTrackable(module.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(NonLayerTrackable, self).__init__()
|
super(NonLayerTrackable, self).__init__()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user