diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index 476776daa8f..37ac8515cb8 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -121,7 +121,10 @@ class _WarmStartSettings( # where ws could be defined as: # Warm-start all weights in the model (input layer and hidden weights). + # Either the directory or a specific checkpoint can be provided (in the case + # of the former, the latest checkpoint will be used). ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp") + ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") # Warm-start only the embeddings (input layer). ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp", @@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var, # TODO(vihanjain): Support _WarmstartSettings where class vocabularies need # remapping too. init = checkpoint_ops._load_and_remap_matrix_initializer( - ckpt_path=saver.latest_checkpoint(prev_ckpt), + ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, new_row_vocab_size=current_vocab_size, new_col_vocab_size=v_shape[1], diff --git a/tensorflow/python/estimator/warm_starting_util_test.py b/tensorflow/python/estimator/warm_starting_util_test.py index cf502dd60de..cc0c4efc756 100644 --- a/tensorflow/python/estimator/warm_starting_util_test.py +++ b/tensorflow/python/estimator/warm_starting_util_test.py @@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase): sess.run(variables.global_variables_initializer()) saver = saver_lib.Saver() ckpt_prefix = os.path.join(self.get_temp_dir(), "model") - ckpt_state_name = "checkpoint" - saver.save( - sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name) + saver.save(sess, ckpt_prefix, global_step=0) def _create_prev_run_var(self, var_name, @@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase): self._assert_cols_to_vars(cols_to_vars, 
{sc_vocab: [prev_vocab_val]}, sess) + def testWarmStart_ExplicitCheckpointFile(self): + # Create vocab for sparse column "sc_vocab". + vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], + "vocab") + # Create feature column. + sc_vocab = fc.categorical_column_with_vocabulary_file( + "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4) + + # Save checkpoint from which to warm-start. + _, prev_vocab_val = self._create_prev_run_var( + "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones()) + + partitioner = lambda shape, dtype: [1] * len(shape) + # New graph, new session WITHOUT warmstarting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + cols_to_vars = self._create_linear_model([sc_vocab], partitioner) + sess.run(variables.global_variables_initializer()) + # Without warmstarting, the weights should be initialized using default + # initializer (which is init_ops.zeros_initializer). + self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]}, + sess) + + # New graph, new session with warmstarting. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + cols_to_vars = self._create_linear_model([sc_vocab], partitioner) + # Since old vocab is not explicitly set in WarmStartSettings, the old + # vocab is assumed to be the same as the new vocab. + ws_util._warmstart(ws_util._WarmStartSettings( + # Explicitly provide the file prefix instead of just the dir. + os.path.join(self.get_temp_dir(), "model-0"), + vars_to_warmstart=".*sc_vocab.*")) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warmstarted. + self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]}, + sess) + def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self): # Create old vocabulary, and use a size smaller than the total number of # entries.