Fix _warmstart_var_with_vocab to truly accept either a checkpoint dir or a direct path to checkpoint files.

PiperOrigin-RevId: 181341437
This commit is contained in:
A. Unique TensorFlower 2018-01-09 10:49:16 -08:00 committed by TensorFlower Gardener
parent 9fcaafd4a9
commit 611f18f179
2 changed files with 43 additions and 4 deletions

View File

@ -121,7 +121,10 @@ class _WarmStartSettings(
# where ws could be defined as:
# Warm-start all weights in the model (input layer and hidden weights).
# Either the directory or a specific checkpoint can be provided (in the case
# of the former, the latest checkpoint will be used).
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp")
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
# Warm-start only the embeddings (input layer).
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp",
@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var,
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
# remapping too.
init = checkpoint_ops._load_and_remap_matrix_initializer(
ckpt_path=saver.latest_checkpoint(prev_ckpt),
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
old_tensor_name=prev_tensor_name,
new_row_vocab_size=current_vocab_size,
new_col_vocab_size=v_shape[1],

View File

@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase):
sess.run(variables.global_variables_initializer())
saver = saver_lib.Saver()
ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
ckpt_state_name = "checkpoint"
saver.save(
sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name)
saver.save(sess, ckpt_prefix, global_step=0)
def _create_prev_run_var(self,
var_name,
@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase):
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
sess)
def testWarmStart_ExplicitCheckpointFile(self):
  """Warm-starts from an explicit checkpoint file prefix rather than a dir.

  Verifies that `_WarmStartSettings.ckpt_to_initialize_from` accepts a
  direct checkpoint prefix ("<tmp>/model-0") instead of only a checkpoint
  directory, and that the warm-started weights match the previous run.
  """
  # Vocab file backing the sparse column "sc_vocab".
  vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                 "vocab")
  sc_vocab = fc.categorical_column_with_vocabulary_file(
      "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
  # Checkpoint from a "previous run" to warm-start from.
  _, prev_vocab_val = self._create_prev_run_var(
      "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
  partitioner = lambda shape, dtype: [1] * len(shape)

  # Phase 1: fresh graph/session with NO warm-starting — weights should come
  # from the column's default initializer (init_ops.zeros_initializer).
  with ops.Graph().as_default() as graph:
    with self.test_session(graph=graph) as session:
      cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
      session.run(variables.global_variables_initializer())
      self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
                                session)

  # Phase 2: fresh graph/session WITH warm-starting. Old vocab is not set in
  # the settings, so it is assumed identical to the new vocab.
  with ops.Graph().as_default() as graph:
    with self.test_session(graph=graph) as session:
      cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
      ws_util._warmstart(ws_util._WarmStartSettings(
          # Explicitly provide the file prefix instead of just the dir.
          os.path.join(self.get_temp_dir(), "model-0"),
          vars_to_warmstart=".*sc_vocab.*"))
      session.run(variables.global_variables_initializer())
      # Weights should now equal those saved in the previous run.
      self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
                                session)
def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
# Create old vocabulary, and use a size smaller than the total number of
# entries.