Add whitelist support in uid of RunConfig.
Change: 154794859
This commit is contained in:
parent
0287e879ac
commit
9e289ce040
tensorflow/contrib/learn/python/learn
@ -31,6 +31,17 @@ from tensorflow.python.estimator import run_config as core_run_config
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
_DEFAULT_UID_WHITE_LIST = [
|
||||
'tf_random_seed',
|
||||
'save_summary_steps',
|
||||
'save_checkpoints_steps',
|
||||
'save_checkpoints_secs',
|
||||
'session_config',
|
||||
'keep_checkpoint_max',
|
||||
'keep_checkpoint_every_n_hours',
|
||||
]
|
||||
|
||||
|
||||
class Environment(object):
|
||||
# For running general distributed training.
|
||||
CLOUD = 'cloud'
|
||||
@ -312,18 +323,29 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
|
||||
return new_copy
|
||||
|
||||
@experimental
|
||||
def uid(self):
|
||||
def uid(self, whitelist=None):
|
||||
"""Generates a 'Unique Identifier' based on all internal fields.
|
||||
|
||||
Caller should use the uid string to check `RunConfig` instance integrity
|
||||
in one session use, but should not rely on the implementation details, which
|
||||
is subject to change.
|
||||
|
||||
Args:
|
||||
whitelist: A list of the string names of the properties uid should not
|
||||
include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which
|
||||
includes most properites user allowes to change.
|
||||
|
||||
Returns:
|
||||
A uid string.
|
||||
"""
|
||||
# TODO(b/33295821): Allows user to specify a whitelist.
|
||||
if whitelist is None:
|
||||
whitelist = _DEFAULT_UID_WHITE_LIST
|
||||
|
||||
state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
|
||||
# Pop out the keys in whitelist.
|
||||
for k in whitelist:
|
||||
state.pop('_' + k, None)
|
||||
|
||||
ordered_state = collections.OrderedDict(
|
||||
sorted(state.items(), key=lambda t: t[0]))
|
||||
# For class instance without __repr__, some special cares are required.
|
||||
|
@ -257,6 +257,51 @@ class RunConfigTest(test.TestCase):
|
||||
self.assertNotEqual(expected_uid, new_config.uid())
|
||||
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
|
||||
|
||||
def test_uid_for_whitelist(self):
|
||||
whitelist = ["model_dir"]
|
||||
config = run_config_lib.RunConfig(
|
||||
tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
|
||||
|
||||
expected_uid = config.uid(whitelist)
|
||||
self.assertEqual(expected_uid, config.uid(whitelist))
|
||||
|
||||
new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
|
||||
self.assertEqual(TEST_DIR, config.model_dir)
|
||||
self.assertEqual(expected_uid, new_config.uid(whitelist))
|
||||
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
|
||||
|
||||
def test_uid_for_default_whitelist(self):
|
||||
config = run_config_lib.RunConfig(
|
||||
tf_random_seed=11,
|
||||
save_summary_steps=12,
|
||||
save_checkpoints_steps=13,
|
||||
save_checkpoints_secs=14,
|
||||
session_config=15,
|
||||
keep_checkpoint_max=16,
|
||||
keep_checkpoint_every_n_hours=17)
|
||||
self.assertEqual(11, config.tf_random_seed)
|
||||
self.assertEqual(12, config.save_summary_steps)
|
||||
self.assertEqual(13, config.save_checkpoints_steps)
|
||||
self.assertEqual(14, config.save_checkpoints_secs)
|
||||
self.assertEqual(15, config.session_config)
|
||||
self.assertEqual(16, config.keep_checkpoint_max)
|
||||
self.assertEqual(17, config.keep_checkpoint_every_n_hours)
|
||||
|
||||
new_config = run_config_lib.RunConfig(
|
||||
tf_random_seed=21,
|
||||
save_summary_steps=22,
|
||||
save_checkpoints_steps=23,
|
||||
save_checkpoints_secs=24,
|
||||
session_config=25,
|
||||
keep_checkpoint_max=26,
|
||||
keep_checkpoint_every_n_hours=27)
|
||||
self.assertEqual(config.uid(), new_config.uid())
|
||||
# model_dir is not on the default whitelist.
|
||||
self.assertNotEqual(config.uid(whitelist=[]),
|
||||
new_config.uid(whitelist=[]))
|
||||
new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR)
|
||||
self.assertNotEqual(config.uid(), new_config.uid())
|
||||
|
||||
def test_uid_for_deepcopy(self):
|
||||
tf_config = {
|
||||
"cluster": {
|
||||
|
@ -293,8 +293,7 @@ class LearnRunnerRunWithRunConfigTest(test.TestCase):
|
||||
def _experiment_fn(run_config, hparams):
|
||||
del run_config, hparams # unused.
|
||||
# Explicitly use a new run_config.
|
||||
new_config = run_config_lib.RunConfig(
|
||||
model_dir=_MODIR_DIR, save_checkpoints_steps=123)
|
||||
new_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR + "/123")
|
||||
|
||||
return TestExperiment(config=new_config)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user