Add whitelist support in uid of RunConfig.

Change: 154794859
This commit is contained in:
Jianwei Xie 2017-05-01 19:29:12 -08:00 committed by TensorFlower Gardener
parent 0287e879ac
commit 9e289ce040
3 changed files with 70 additions and 4 deletions
tensorflow/contrib/learn/python/learn

View File

@ -31,6 +31,17 @@ from tensorflow.python.estimator import run_config as core_run_config
from tensorflow.python.training import server_lib
_DEFAULT_UID_WHITE_LIST = [
'tf_random_seed',
'save_summary_steps',
'save_checkpoints_steps',
'save_checkpoints_secs',
'session_config',
'keep_checkpoint_max',
'keep_checkpoint_every_n_hours',
]
class Environment(object):
# For running general distributed training.
CLOUD = 'cloud'
@ -312,18 +323,29 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
return new_copy
@experimental
def uid(self):
def uid(self, whitelist=None):
"""Generates a 'Unique Identifier' based on all internal fields.
Caller should use the uid string to check `RunConfig` instance integrity
in one session use, but should not rely on the implementation details, which
is subject to change.
Args:
whitelist: A list of the string names of the properties uid should not
include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which
includes most properites user allowes to change.
Returns:
A uid string.
"""
# TODO(b/33295821): Allows user to specify a whitelist.
if whitelist is None:
whitelist = _DEFAULT_UID_WHITE_LIST
state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
# Pop out the keys in whitelist.
for k in whitelist:
state.pop('_' + k, None)
ordered_state = collections.OrderedDict(
sorted(state.items(), key=lambda t: t[0]))
# For class instance without __repr__, some special cares are required.

View File

@ -257,6 +257,51 @@ class RunConfigTest(test.TestCase):
self.assertNotEqual(expected_uid, new_config.uid())
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
def test_uid_for_whitelist(self):
whitelist = ["model_dir"]
config = run_config_lib.RunConfig(
tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
expected_uid = config.uid(whitelist)
self.assertEqual(expected_uid, config.uid(whitelist))
new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
self.assertEqual(TEST_DIR, config.model_dir)
self.assertEqual(expected_uid, new_config.uid(whitelist))
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
def test_uid_for_default_whitelist(self):
config = run_config_lib.RunConfig(
tf_random_seed=11,
save_summary_steps=12,
save_checkpoints_steps=13,
save_checkpoints_secs=14,
session_config=15,
keep_checkpoint_max=16,
keep_checkpoint_every_n_hours=17)
self.assertEqual(11, config.tf_random_seed)
self.assertEqual(12, config.save_summary_steps)
self.assertEqual(13, config.save_checkpoints_steps)
self.assertEqual(14, config.save_checkpoints_secs)
self.assertEqual(15, config.session_config)
self.assertEqual(16, config.keep_checkpoint_max)
self.assertEqual(17, config.keep_checkpoint_every_n_hours)
new_config = run_config_lib.RunConfig(
tf_random_seed=21,
save_summary_steps=22,
save_checkpoints_steps=23,
save_checkpoints_secs=24,
session_config=25,
keep_checkpoint_max=26,
keep_checkpoint_every_n_hours=27)
self.assertEqual(config.uid(), new_config.uid())
# model_dir is not on the default whitelist.
self.assertNotEqual(config.uid(whitelist=[]),
new_config.uid(whitelist=[]))
new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR)
self.assertNotEqual(config.uid(), new_config.uid())
def test_uid_for_deepcopy(self):
tf_config = {
"cluster": {

View File

@ -293,8 +293,7 @@ class LearnRunnerRunWithRunConfigTest(test.TestCase):
def _experiment_fn(run_config, hparams):
del run_config, hparams # unused.
# Explicitly use a new run_config.
new_config = run_config_lib.RunConfig(
model_dir=_MODIR_DIR, save_checkpoints_steps=123)
new_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR + "/123")
return TestExperiment(config=new_config)