Add whitelist support in uid of RunConfig.

Change: 154794859
This commit is contained in:
Jianwei Xie 2017-05-01 19:29:12 -08:00 committed by TensorFlower Gardener
parent 0287e879ac
commit 9e289ce040
3 changed files with 70 additions and 4 deletions

View File

@ -31,6 +31,17 @@ from tensorflow.python.estimator import run_config as core_run_config
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
_DEFAULT_UID_WHITE_LIST = [
'tf_random_seed',
'save_summary_steps',
'save_checkpoints_steps',
'save_checkpoints_secs',
'session_config',
'keep_checkpoint_max',
'keep_checkpoint_every_n_hours',
]
class Environment(object): class Environment(object):
# For running general distributed training. # For running general distributed training.
CLOUD = 'cloud' CLOUD = 'cloud'
@ -312,18 +323,29 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
return new_copy return new_copy
@experimental @experimental
def uid(self): def uid(self, whitelist=None):
"""Generates a 'Unique Identifier' based on all internal fields. """Generates a 'Unique Identifier' based on all internal fields.
Caller should use the uid string to check `RunConfig` instance integrity Caller should use the uid string to check `RunConfig` instance integrity
in one session use, but should not rely on the implementation details, which in one session use, but should not rely on the implementation details, which
is subject to change. is subject to change.
Args:
whitelist: A list of the string names of the properties uid should not
include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which
includes most properites user allowes to change.
Returns: Returns:
A uid string. A uid string.
""" """
# TODO(b/33295821): Allows user to specify a whitelist. if whitelist is None:
whitelist = _DEFAULT_UID_WHITE_LIST
state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')} state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
# Pop out the keys in whitelist.
for k in whitelist:
state.pop('_' + k, None)
ordered_state = collections.OrderedDict( ordered_state = collections.OrderedDict(
sorted(state.items(), key=lambda t: t[0])) sorted(state.items(), key=lambda t: t[0]))
# For class instance without __repr__, some special cares are required. # For class instance without __repr__, some special cares are required.

View File

@ -257,6 +257,51 @@ class RunConfigTest(test.TestCase):
self.assertNotEqual(expected_uid, new_config.uid()) self.assertNotEqual(expected_uid, new_config.uid())
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
def test_uid_for_whitelist(self):
whitelist = ["model_dir"]
config = run_config_lib.RunConfig(
tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
expected_uid = config.uid(whitelist)
self.assertEqual(expected_uid, config.uid(whitelist))
new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
self.assertEqual(TEST_DIR, config.model_dir)
self.assertEqual(expected_uid, new_config.uid(whitelist))
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
def test_uid_for_default_whitelist(self):
config = run_config_lib.RunConfig(
tf_random_seed=11,
save_summary_steps=12,
save_checkpoints_steps=13,
save_checkpoints_secs=14,
session_config=15,
keep_checkpoint_max=16,
keep_checkpoint_every_n_hours=17)
self.assertEqual(11, config.tf_random_seed)
self.assertEqual(12, config.save_summary_steps)
self.assertEqual(13, config.save_checkpoints_steps)
self.assertEqual(14, config.save_checkpoints_secs)
self.assertEqual(15, config.session_config)
self.assertEqual(16, config.keep_checkpoint_max)
self.assertEqual(17, config.keep_checkpoint_every_n_hours)
new_config = run_config_lib.RunConfig(
tf_random_seed=21,
save_summary_steps=22,
save_checkpoints_steps=23,
save_checkpoints_secs=24,
session_config=25,
keep_checkpoint_max=26,
keep_checkpoint_every_n_hours=27)
self.assertEqual(config.uid(), new_config.uid())
# model_dir is not on the default whitelist.
self.assertNotEqual(config.uid(whitelist=[]),
new_config.uid(whitelist=[]))
new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR)
self.assertNotEqual(config.uid(), new_config.uid())
def test_uid_for_deepcopy(self): def test_uid_for_deepcopy(self):
tf_config = { tf_config = {
"cluster": { "cluster": {

View File

@ -293,8 +293,7 @@ class LearnRunnerRunWithRunConfigTest(test.TestCase):
def _experiment_fn(run_config, hparams): def _experiment_fn(run_config, hparams):
del run_config, hparams # unused. del run_config, hparams # unused.
# Explicitly use a new run_config. # Explicitly use a new run_config.
new_config = run_config_lib.RunConfig( new_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR + "/123")
model_dir=_MODIR_DIR, save_checkpoints_steps=123)
return TestExperiment(config=new_config) return TestExperiment(config=new_config)