From 83a712dc5397a67e69fbdee3ec4b833923fc8727 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Tue, 14 Jul 2020 21:48:19 -0700 Subject: [PATCH] Remove `Autotrackable` from the Keras test files and replacing it with `Module`. PiperOrigin-RevId: 321296275 Change-Id: Idba3ced8d12b163daa254f04535a4542a1d9a45b --- .../keras/tests/convert_to_constants_test.py | 6 ++--- tensorflow/python/keras/tests/saver_test.py | 4 ++-- .../python/keras/tests/tracking_test.py | 23 +++++++++---------- .../python/keras/tests/tracking_util_test.py | 6 ++--- .../tracking_util_with_v1_optimizers_test.py | 6 ++--- .../keras/tests/tracking_util_xla_test.py | 4 ++-- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/tensorflow/python/keras/tests/convert_to_constants_test.py b/tensorflow/python/keras/tests/convert_to_constants_test.py index 21081682089..f59c83b79dc 100644 --- a/tensorflow/python/keras/tests/convert_to_constants_test.py +++ b/tensorflow/python/keras/tests/convert_to_constants_test.py @@ -29,12 +29,12 @@ from tensorflow.python.framework import convert_to_constants from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.platform import test from tensorflow.python.saved_model.load import load from tensorflow.python.saved_model.save import save -from tensorflow.python.training.tracking import tracking from tensorflow.python.util import nest @@ -50,7 +50,7 @@ class VariablesToConstantsTest(test.TestCase): root: AutoTrackable object with original ConcreteFunction. output_func: frozen ConcreteFunction. """ - root = tracking.AutoTrackable() + root = module.Module() root.f = model input_func = root.f.get_concrete_function() @@ -91,7 +91,7 @@ class VariablesToConstantsTest(test.TestCase): # Save the converted ConcreteFunction as a signature. save_dir = os.path.join(self.get_temp_dir(), "frozen_saved_model") - root = tracking.AutoTrackable() + root = module.Module() root.f = converted_concrete_func save(root, save_dir, {"mykey": converted_concrete_func}) diff --git a/tensorflow/python/keras/tests/saver_test.py b/tensorflow/python/keras/tests/saver_test.py index 28c65961a53..03496544033 100644 --- a/tensorflow/python/keras/tests/saver_test.py +++ b/tensorflow/python/keras/tests/saver_test.py @@ -27,16 +27,16 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops as ops_lib from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core +from tensorflow.python.module import module from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import adam from tensorflow.python.training import saver as saver_module from tensorflow.python.training import training_util -from tensorflow.python.training.tracking import tracking as trackable_tracking from tensorflow.python.training.tracking import util as trackable_utils -class NonLayerTrackable(trackable_tracking.AutoTrackable): +class NonLayerTrackable(module.Module): def __init__(self): super(NonLayerTrackable, self).__init__() diff --git a/tensorflow/python/keras/tests/tracking_test.py b/tensorflow/python/keras/tests/tracking_test.py index cef5e603dfd..02d5cd519ab 100644 --- a/tensorflow/python/keras/tests/tracking_test.py +++ b/tensorflow/python/keras/tests/tracking_test.py @@ -39,7 +39,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import data_structures -from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import util @@ -290,7 +289,7 @@ class MappingTests(keras_parameterized.TestCase): def testLayerCollectionWithExternalMutation(self): d = {} - root = tracking.AutoTrackable() + root = module.Module() root.wrapper = d self.assertEqual([], root.wrapper.layers) self.assertEqual([], root.wrapper.trainable_weights) @@ -303,7 +302,7 @@ class MappingTests(keras_parameterized.TestCase): self.assertEqual([], root.wrapper.trainable_weights) def testDictWrapperBadKeys(self): - a = tracking.AutoTrackable() + a = module.Module() a.d = {} a.d[1] = data_structures.List() model = training.Model() @@ -313,7 +312,7 @@ class MappingTests(keras_parameterized.TestCase): model.save_weights(save_path) def testDictWrapperNoDependency(self): - a = tracking.AutoTrackable() + a = module.Module() a.d = data_structures.NoDependency({}) a.d[1] = [3] self.assertEqual([a], util.list_objects(a)) @@ -324,7 +323,7 @@ class MappingTests(keras_parameterized.TestCase): model.load_weights(save_path) def testNonStringKeyNotTrackableValue(self): - a = tracking.AutoTrackable() + a = module.Module() a.d = {} a.d["a"] = [3] a.d[1] = data_structures.NoDependency([3]) @@ -338,15 +337,15 @@ class MappingTests(keras_parameterized.TestCase): def testNonAppendNotTrackable(self): # Non-append mutations (deleting or overwriting values) are OK when the # values aren't tracked. - a = tracking.AutoTrackable() + a = module.Module() a.d = {} a.d["a"] = [3] a.d[1] = 3 a.d[1] = 2 self.assertEqual(2, a.d[1]) del a.d[1] - a.d[2] = data_structures.NoDependency(tracking.AutoTrackable()) - second = tracking.AutoTrackable() + a.d[2] = data_structures.NoDependency(module.Module()) + second = module.Module() a.d[2] = data_structures.NoDependency(second) self.assertIs(second, a.d[2]) self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a)) @@ -550,10 +549,10 @@ class TupleTests(keras_parameterized.TestCase): class InterfaceTests(keras_parameterized.TestCase): def testNoDependency(self): - root = tracking.AutoTrackable() - hasdep = tracking.AutoTrackable() + root = module.Module() + hasdep = module.Module() root.hasdep = hasdep - nodep = tracking.AutoTrackable() + nodep = module.Module() root.nodep = data_structures.NoDependency(nodep) self.assertEqual(1, len(root._checkpoint_dependencies)) self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) @@ -566,7 +565,7 @@ class InterfaceTests(keras_parameterized.TestCase): def __init__(self): super(NoDependencyModel, self).__init__() self.a = [] - self.b = tracking.AutoTrackable() + self.b = module.Module() nodeps = NoDependencyModel() self.assertEqual([nodeps], util.list_objects(nodeps)) diff --git a/tensorflow/python/keras/tests/tracking_util_test.py b/tensorflow/python/keras/tests/tracking_util_test.py index a609d4f711e..32b3ceec6f6 100644 --- a/tensorflow/python/keras/tests/tracking_util_test.py +++ b/tensorflow/python/keras/tests/tracking_util_test.py @@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.module import module from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops @@ -48,7 +49,6 @@ from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.training.tracking import graph_view -from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import util as trackable_utils @@ -68,7 +68,7 @@ class MyModel(training.Model): return ret -class NonLayerTrackable(tracking.AutoTrackable): +class NonLayerTrackable(module.Module): def __init__(self): super(NonLayerTrackable, self).__init__() @@ -709,7 +709,7 @@ class CheckpointingTests(keras_parameterized.TestCase): self.assertEqual(42., self.evaluate(optimizer.beta_1)) -class _ManualScope(tracking.AutoTrackable): +class _ManualScope(module.Module): def __call__(self): with variable_scope.variable_scope("ManualScope") as vs: diff --git a/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py b/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py index 4583616f4d9..1ba76c19866 100644 --- a/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py +++ b/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py @@ -35,6 +35,7 @@ from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core +from tensorflow.python.module import module from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.training import adam @@ -42,11 +43,10 @@ from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.training.tracking import graph_view -from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import util as trackable_utils -class NonLayerTrackable(tracking.AutoTrackable): +class NonLayerTrackable(module.Module): def __init__(self): super(NonLayerTrackable, self).__init__() @@ -460,7 +460,7 @@ class CheckpointingTests(keras_parameterized.TestCase): # pylint: enable=cell-var-from-loop def _get_checkpoint_name(self, name): - root = tracking.AutoTrackable() + root = module.Module() trackable_utils.add_variable( root, name=name, shape=[1, 2], dtype=dtypes.float64) (named_variable,), _, _ = trackable_utils._serialize_object_graph( diff --git a/tensorflow/python/keras/tests/tracking_util_xla_test.py b/tensorflow/python/keras/tests/tracking_util_xla_test.py index 4e8dd0a6fd3..0a311011c5a 100644 --- a/tensorflow/python/keras/tests/tracking_util_xla_test.py +++ b/tensorflow/python/keras/tests/tracking_util_xla_test.py @@ -23,13 +23,13 @@ from tensorflow.python.framework import ops from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.module import module from tensorflow.python.platform import test from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import util as trackable_utils -class NonLayerTrackable(tracking.AutoTrackable): +class NonLayerTrackable(module.Module): def __init__(self): super(NonLayerTrackable, self).__init__()