Move keras model related template test case to keras model test.
PiperOrigin-RevId: 306166092 Change-Id: I798e1039c43795c781f7125e4f6b61f9641cc587
This commit is contained in:
parent
4e45372967
commit
716bcea5a2
@ -48,11 +48,14 @@ from tensorflow.python.keras.engine import training_utils
|
|||||||
from tensorflow.python.keras.utils import data_utils
|
from tensorflow.python.keras.utils import data_utils
|
||||||
from tensorflow.python.keras.utils import np_utils
|
from tensorflow.python.keras.utils import np_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
|
from tensorflow.python.ops import template
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||||
@ -796,6 +799,65 @@ class TrainingTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(l.non_trainable_variables, [l.layer1.non_trainable_var])
|
self.assertEqual(l.non_trainable_variables, [l.layer1.non_trainable_var])
|
||||||
self.assertLen(l.get_weights(), 2)
|
self.assertLen(l.get_weights(), 2)
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
def test_weight_tracking_for_template(self):
|
||||||
|
def variable_scoped_function(trainable=True):
|
||||||
|
return variable_scope.get_variable(
|
||||||
|
'dummy', shape=[1], trainable=trainable,
|
||||||
|
initializer=init_ops.zeros_initializer())
|
||||||
|
def nested_template():
|
||||||
|
nested1 = template.make_template('nested', variable_scoped_function)
|
||||||
|
nested2 = template.make_template('nested', variable_scoped_function)
|
||||||
|
v1 = nested1()
|
||||||
|
v2 = nested2()
|
||||||
|
|
||||||
|
# nested1 and nested2 should not share variables
|
||||||
|
self.assertIsNot(v1, v2)
|
||||||
|
|
||||||
|
# Variables created by nested1 should be isolated from variables
|
||||||
|
# created by nested2.
|
||||||
|
self.assertEqual(1, len(nested1.variables))
|
||||||
|
self.assertEqual(1, len(nested2.variables))
|
||||||
|
self.assertIs(nested1.variables[0], v1)
|
||||||
|
self.assertIs(nested2.variables[0], v2)
|
||||||
|
self.assertEqual(1, len(nested1.trainable_variables))
|
||||||
|
self.assertEqual(1, len(nested2.trainable_variables))
|
||||||
|
self.assertIs(nested1.trainable_variables[0], v1)
|
||||||
|
self.assertIs(nested2.trainable_variables[0], v2)
|
||||||
|
self.assertEqual(len(nested1.non_trainable_variables), 0)
|
||||||
|
self.assertEqual(len(nested2.non_trainable_variables), 0)
|
||||||
|
return v1, v2
|
||||||
|
|
||||||
|
tmpl1 = template.make_template('s1', nested_template)
|
||||||
|
tmpl2 = template.make_template('s1', nested_template)
|
||||||
|
|
||||||
|
v1, v2 = tmpl1()
|
||||||
|
v5, v6 = tmpl2()
|
||||||
|
|
||||||
|
model = training_module.Model()
|
||||||
|
model.template = tmpl1
|
||||||
|
self.assertEqual(2, len(model.variables))
|
||||||
|
self.assertIs(model.variables[0], v1)
|
||||||
|
self.assertIs(model.variables[1], v2)
|
||||||
|
self.assertEqual(2, len(model.variables))
|
||||||
|
self.assertIs(model.trainable_variables[0], v1)
|
||||||
|
self.assertIs(model.trainable_variables[1], v2)
|
||||||
|
self.assertEqual(len(model.non_trainable_variables), 0)
|
||||||
|
model.templates = [tmpl2]
|
||||||
|
for v, w in zip(model.variables, [v1, v2, v5, v6]):
|
||||||
|
self.assertIs(v, w)
|
||||||
|
for v, w in zip(model.trainable_variables, [v1, v2, v5, v6]):
|
||||||
|
self.assertIs(v, w)
|
||||||
|
self.assertEqual(len(model.non_trainable_variables), 0)
|
||||||
|
# Make sure losses, layers, and updates aren't broken by having a Template
|
||||||
|
# in the mix, which does not expose any updates or losses.
|
||||||
|
self.assertEqual([], model.layers)
|
||||||
|
self.assertEqual([], model.updates)
|
||||||
|
self.assertEqual([], model.losses)
|
||||||
|
self.assertEqual([], model.templates.layers)
|
||||||
|
self.assertEqual([], model.templates.updates)
|
||||||
|
self.assertEqual([], model.templates.losses)
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
def test_logs_passed_to_callbacks(self):
|
def test_logs_passed_to_callbacks(self):
|
||||||
input_dim = 5
|
input_dim = 5
|
||||||
|
@ -25,7 +25,6 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras.engine import training
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -390,29 +389,6 @@ class TemplateTest(test.TestCase):
|
|||||||
self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
|
self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
|
||||||
self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
|
self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
|
||||||
self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
|
self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
|
||||||
model = training.Model()
|
|
||||||
model.template = tmpl1
|
|
||||||
self.assertEqual(2, len(model.variables))
|
|
||||||
self.assertIs(model.variables[0], v1)
|
|
||||||
self.assertIs(model.variables[1], v2)
|
|
||||||
self.assertEqual(2, len(model.variables))
|
|
||||||
self.assertIs(model.trainable_variables[0], v1)
|
|
||||||
self.assertIs(model.trainable_variables[1], v2)
|
|
||||||
self.assertEqual(len(model.non_trainable_variables), 0)
|
|
||||||
model.templates = [tmpl2]
|
|
||||||
for v, w in zip(model.variables, [v1, v2, v5, v6]):
|
|
||||||
self.assertIs(v, w)
|
|
||||||
for v, w in zip(model.trainable_variables, [v1, v2, v5, v6]):
|
|
||||||
self.assertIs(v, w)
|
|
||||||
self.assertEqual(len(model.non_trainable_variables), 0)
|
|
||||||
# Make sure losses, layers, and updates aren't broken by having a Template
|
|
||||||
# in the mix, which does not expose any updates or losses.
|
|
||||||
self.assertEqual([], model.layers)
|
|
||||||
self.assertEqual([], model.updates)
|
|
||||||
self.assertEqual([], model.losses)
|
|
||||||
self.assertEqual([], model.templates.layers)
|
|
||||||
self.assertEqual([], model.templates.updates)
|
|
||||||
self.assertEqual([], model.templates.losses)
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_nested_templates_with_defun(self):
|
def test_nested_templates_with_defun(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user