From 9e289ce04020f01f4d8c537f2c399fddae4be019 Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Mon, 1 May 2017 19:29:12 -0800 Subject: [PATCH] Add whitelist support in uid of RunConfig. Change: 154794859 --- .../python/learn/estimators/run_config.py | 26 ++++++++++- .../learn/estimators/run_config_test.py | 45 +++++++++++++++++++ .../learn/python/learn/learn_runner_test.py | 3 +- 3 files changed, 70 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 109c8d25e12..5a63ee7fa82 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -31,6 +31,17 @@ from tensorflow.python.estimator import run_config as core_run_config from tensorflow.python.training import server_lib +_DEFAULT_UID_WHITE_LIST = [ + 'tf_random_seed', + 'save_summary_steps', + 'save_checkpoints_steps', + 'save_checkpoints_secs', + 'session_config', + 'keep_checkpoint_max', + 'keep_checkpoint_every_n_hours', +] + + class Environment(object): # For running general distributed training. CLOUD = 'cloud' @@ -312,18 +323,29 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): return new_copy @experimental - def uid(self): + def uid(self, whitelist=None): """Generates a 'Unique Identifier' based on all internal fields. Caller should use the uid string to check `RunConfig` instance integrity in one session use, but should not rely on the implementation details, which is subject to change. + Args: + whitelist: A list of the string names of the properties uid should not + include. If `None`, defaults to `_DEFAULT_UID_WHITE_LIST`, which + includes most properties the user is allowed to change. + Returns: A uid string. """ - # TODO(b/33295821): Allows user to specify a whitelist. 
+ if whitelist is None: + whitelist = _DEFAULT_UID_WHITE_LIST + state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')} + # Pop out the keys in whitelist. + for k in whitelist: + state.pop('_' + k, None) + ordered_state = collections.OrderedDict( sorted(state.items(), key=lambda t: t[0])) # For class instance without __repr__, some special cares are required. diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py index 14cef7cc43d..6d39a9ad137 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py @@ -257,6 +257,51 @@ class RunConfigTest(test.TestCase): self.assertNotEqual(expected_uid, new_config.uid()) self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) + def test_uid_for_whitelist(self): + whitelist = ["model_dir"] + config = run_config_lib.RunConfig( + tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) + + expected_uid = config.uid(whitelist) + self.assertEqual(expected_uid, config.uid(whitelist)) + + new_config = config.replace(model_dir=ANOTHER_TEST_DIR) + self.assertEqual(TEST_DIR, config.model_dir) + self.assertEqual(expected_uid, new_config.uid(whitelist)) + self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) + + def test_uid_for_default_whitelist(self): + config = run_config_lib.RunConfig( + tf_random_seed=11, + save_summary_steps=12, + save_checkpoints_steps=13, + save_checkpoints_secs=14, + session_config=15, + keep_checkpoint_max=16, + keep_checkpoint_every_n_hours=17) + self.assertEqual(11, config.tf_random_seed) + self.assertEqual(12, config.save_summary_steps) + self.assertEqual(13, config.save_checkpoints_steps) + self.assertEqual(14, config.save_checkpoints_secs) + self.assertEqual(15, config.session_config) + self.assertEqual(16, config.keep_checkpoint_max) + self.assertEqual(17, config.keep_checkpoint_every_n_hours) + + 
new_config = run_config_lib.RunConfig( + tf_random_seed=21, + save_summary_steps=22, + save_checkpoints_steps=23, + save_checkpoints_secs=24, + session_config=25, + keep_checkpoint_max=26, + keep_checkpoint_every_n_hours=27) + self.assertEqual(config.uid(), new_config.uid()) + # model_dir is not on the default whitelist. + self.assertNotEqual(config.uid(whitelist=[]), + new_config.uid(whitelist=[])) + new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR) + self.assertNotEqual(config.uid(), new_config.uid()) + def test_uid_for_deepcopy(self): tf_config = { "cluster": { diff --git a/tensorflow/contrib/learn/python/learn/learn_runner_test.py b/tensorflow/contrib/learn/python/learn/learn_runner_test.py index 6c8cde453f3..77bdcaeb7ed 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner_test.py @@ -293,8 +293,7 @@ class LearnRunnerRunWithRunConfigTest(test.TestCase): def _experiment_fn(run_config, hparams): del run_config, hparams # unused. # Explicitly use a new run_config. - new_config = run_config_lib.RunConfig( - model_dir=_MODIR_DIR, save_checkpoints_steps=123) + new_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR + "/123") return TestExperiment(config=new_config)