Move keras model related template test case to keras model test.
PiperOrigin-RevId: 306166092 Change-Id: I798e1039c43795c781f7125e4f6b61f9641cc587
This commit is contained in:
parent
4e45372967
commit
716bcea5a2
@ -48,11 +48,14 @@ from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
from tensorflow.python.keras.utils import np_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import template
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.rmsprop import RMSPropOptimizer
|
||||
@ -796,6 +799,65 @@ class TrainingTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(l.non_trainable_variables, [l.layer1.non_trainable_var])
|
||||
self.assertLen(l.get_weights(), 2)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_weight_tracking_for_template(self):
|
||||
def variable_scoped_function(trainable=True):
|
||||
return variable_scope.get_variable(
|
||||
'dummy', shape=[1], trainable=trainable,
|
||||
initializer=init_ops.zeros_initializer())
|
||||
def nested_template():
|
||||
nested1 = template.make_template('nested', variable_scoped_function)
|
||||
nested2 = template.make_template('nested', variable_scoped_function)
|
||||
v1 = nested1()
|
||||
v2 = nested2()
|
||||
|
||||
# nested1 and nested2 should not share variables
|
||||
self.assertIsNot(v1, v2)
|
||||
|
||||
# Variables created by nested1 should be isolated from variables
|
||||
# created by nested2.
|
||||
self.assertEqual(1, len(nested1.variables))
|
||||
self.assertEqual(1, len(nested2.variables))
|
||||
self.assertIs(nested1.variables[0], v1)
|
||||
self.assertIs(nested2.variables[0], v2)
|
||||
self.assertEqual(1, len(nested1.trainable_variables))
|
||||
self.assertEqual(1, len(nested2.trainable_variables))
|
||||
self.assertIs(nested1.trainable_variables[0], v1)
|
||||
self.assertIs(nested2.trainable_variables[0], v2)
|
||||
self.assertEqual(len(nested1.non_trainable_variables), 0)
|
||||
self.assertEqual(len(nested2.non_trainable_variables), 0)
|
||||
return v1, v2
|
||||
|
||||
tmpl1 = template.make_template('s1', nested_template)
|
||||
tmpl2 = template.make_template('s1', nested_template)
|
||||
|
||||
v1, v2 = tmpl1()
|
||||
v5, v6 = tmpl2()
|
||||
|
||||
model = training_module.Model()
|
||||
model.template = tmpl1
|
||||
self.assertEqual(2, len(model.variables))
|
||||
self.assertIs(model.variables[0], v1)
|
||||
self.assertIs(model.variables[1], v2)
|
||||
self.assertEqual(2, len(model.variables))
|
||||
self.assertIs(model.trainable_variables[0], v1)
|
||||
self.assertIs(model.trainable_variables[1], v2)
|
||||
self.assertEqual(len(model.non_trainable_variables), 0)
|
||||
model.templates = [tmpl2]
|
||||
for v, w in zip(model.variables, [v1, v2, v5, v6]):
|
||||
self.assertIs(v, w)
|
||||
for v, w in zip(model.trainable_variables, [v1, v2, v5, v6]):
|
||||
self.assertIs(v, w)
|
||||
self.assertEqual(len(model.non_trainable_variables), 0)
|
||||
# Make sure losses, layers, and updates aren't broken by having a Template
|
||||
# in the mix, which does not expose any updates or losses.
|
||||
self.assertEqual([], model.layers)
|
||||
self.assertEqual([], model.updates)
|
||||
self.assertEqual([], model.losses)
|
||||
self.assertEqual([], model.templates.layers)
|
||||
self.assertEqual([], model.templates.updates)
|
||||
self.assertEqual([], model.templates.losses)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_logs_passed_to_callbacks(self):
|
||||
input_dim = 5
|
||||
|
@ -25,7 +25,6 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -390,29 +389,6 @@ class TemplateTest(test.TestCase):
|
||||
self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
|
||||
self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
|
||||
self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
|
||||
model = training.Model()
|
||||
model.template = tmpl1
|
||||
self.assertEqual(2, len(model.variables))
|
||||
self.assertIs(model.variables[0], v1)
|
||||
self.assertIs(model.variables[1], v2)
|
||||
self.assertEqual(2, len(model.variables))
|
||||
self.assertIs(model.trainable_variables[0], v1)
|
||||
self.assertIs(model.trainable_variables[1], v2)
|
||||
self.assertEqual(len(model.non_trainable_variables), 0)
|
||||
model.templates = [tmpl2]
|
||||
for v, w in zip(model.variables, [v1, v2, v5, v6]):
|
||||
self.assertIs(v, w)
|
||||
for v, w in zip(model.trainable_variables, [v1, v2, v5, v6]):
|
||||
self.assertIs(v, w)
|
||||
self.assertEqual(len(model.non_trainable_variables), 0)
|
||||
# Make sure losses, layers, and updates aren't broken by having a Template
|
||||
# in the mix, which does not expose any updates or losses.
|
||||
self.assertEqual([], model.layers)
|
||||
self.assertEqual([], model.updates)
|
||||
self.assertEqual([], model.losses)
|
||||
self.assertEqual([], model.templates.layers)
|
||||
self.assertEqual([], model.templates.updates)
|
||||
self.assertEqual([], model.templates.losses)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_nested_templates_with_defun(self):
|
||||
|
Loading…
Reference in New Issue
Block a user