Move keras model related template test case to keras model test.

PiperOrigin-RevId: 306166092
Change-Id: I798e1039c43795c781f7125e4f6b61f9641cc587
This commit is contained in:
Scott Zhu 2020-04-12 19:40:08 -07:00 committed by TensorFlower Gardener
parent 4e45372967
commit 716bcea5a2
2 changed files with 62 additions and 24 deletions

View File

@ -48,11 +48,14 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@ -796,6 +799,65 @@ class TrainingTest(keras_parameterized.TestCase):
self.assertEqual(l.non_trainable_variables, [l.layer1.non_trainable_var])
self.assertLen(l.get_weights(), 2)
@keras_parameterized.run_all_keras_modes
def test_weight_tracking_for_template(self):
def variable_scoped_function(trainable=True):
return variable_scope.get_variable(
'dummy', shape=[1], trainable=trainable,
initializer=init_ops.zeros_initializer())
def nested_template():
nested1 = template.make_template('nested', variable_scoped_function)
nested2 = template.make_template('nested', variable_scoped_function)
v1 = nested1()
v2 = nested2()
# nested1 and nested2 should not share variables
self.assertIsNot(v1, v2)
# Variables created by nested1 should be isolated from variables
# created by nested2.
self.assertEqual(1, len(nested1.variables))
self.assertEqual(1, len(nested2.variables))
self.assertIs(nested1.variables[0], v1)
self.assertIs(nested2.variables[0], v2)
self.assertEqual(1, len(nested1.trainable_variables))
self.assertEqual(1, len(nested2.trainable_variables))
self.assertIs(nested1.trainable_variables[0], v1)
self.assertIs(nested2.trainable_variables[0], v2)
self.assertEqual(len(nested1.non_trainable_variables), 0)
self.assertEqual(len(nested2.non_trainable_variables), 0)
return v1, v2
tmpl1 = template.make_template('s1', nested_template)
tmpl2 = template.make_template('s1', nested_template)
v1, v2 = tmpl1()
v5, v6 = tmpl2()
model = training_module.Model()
model.template = tmpl1
self.assertEqual(2, len(model.variables))
self.assertIs(model.variables[0], v1)
self.assertIs(model.variables[1], v2)
self.assertEqual(2, len(model.variables))
self.assertIs(model.trainable_variables[0], v1)
self.assertIs(model.trainable_variables[1], v2)
self.assertEqual(len(model.non_trainable_variables), 0)
model.templates = [tmpl2]
for v, w in zip(model.variables, [v1, v2, v5, v6]):
self.assertIs(v, w)
for v, w in zip(model.trainable_variables, [v1, v2, v5, v6]):
self.assertIs(v, w)
self.assertEqual(len(model.non_trainable_variables), 0)
# Make sure losses, layers, and updates aren't broken by having a Template
# in the mix, which does not expose any updates or losses.
self.assertEqual([], model.layers)
self.assertEqual([], model.updates)
self.assertEqual([], model.losses)
self.assertEqual([], model.templates.layers)
self.assertEqual([], model.templates.updates)
self.assertEqual([], model.templates.losses)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_logs_passed_to_callbacks(self):
input_dim = 5

View File

@ -25,7 +25,6 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@ -390,29 +389,6 @@ class TemplateTest(test.TestCase):
self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
model = training.Model()
model.template = tmpl1
self.assertEqual(2, len(model.variables))
self.assertIs(model.variables[0], v1)
self.assertIs(model.variables[1], v2)
self.assertEqual(2, len(model.variables))
self.assertIs(model.trainable_variables[0], v1)
self.assertIs(model.trainable_variables[1], v2)
self.assertEqual(len(model.non_trainable_variables), 0)
model.templates = [tmpl2]
for v, w in zip(model.variables, [v1, v2, v5, v6]):
self.assertIs(v, w)
for v, w in zip(model.trainable_variables, [v1, v2, v5, v6]):
self.assertIs(v, w)
self.assertEqual(len(model.non_trainable_variables), 0)
# Make sure losses, layers, and updates aren't broken by having a Template
# in the mix, which does not expose any updates or losses.
self.assertEqual([], model.layers)
self.assertEqual([], model.updates)
self.assertEqual([], model.losses)
self.assertEqual([], model.templates.layers)
self.assertEqual([], model.templates.updates)
self.assertEqual([], model.templates.losses)
@test_util.run_in_graph_and_eager_modes
def test_nested_templates_with_defun(self):