Fix _warmstart_var_with_vocab to truly accept either a checkpoint dir or a direct path to checkpoint files.
PiperOrigin-RevId: 181341437
This commit is contained in:
parent
9fcaafd4a9
commit
611f18f179
@ -121,7 +121,10 @@ class _WarmStartSettings(
|
||||
# where ws could be defined as:
|
||||
|
||||
# Warm-start all weights in the model (input layer and hidden weights).
|
||||
# Either the directory or a specific checkpoint can be provided (in the case
|
||||
# of the former, the latest checkpoint will be used).
|
||||
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp")
|
||||
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
|
||||
|
||||
# Warm-start only the embeddings (input layer).
|
||||
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp",
|
||||
@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var,
|
||||
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
|
||||
# remapping too.
|
||||
init = checkpoint_ops._load_and_remap_matrix_initializer(
|
||||
ckpt_path=saver.latest_checkpoint(prev_ckpt),
|
||||
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
|
||||
old_tensor_name=prev_tensor_name,
|
||||
new_row_vocab_size=current_vocab_size,
|
||||
new_col_vocab_size=v_shape[1],
|
||||
|
@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase):
|
||||
sess.run(variables.global_variables_initializer())
|
||||
saver = saver_lib.Saver()
|
||||
ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
|
||||
ckpt_state_name = "checkpoint"
|
||||
saver.save(
|
||||
sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name)
|
||||
saver.save(sess, ckpt_prefix, global_step=0)
|
||||
|
||||
def _create_prev_run_var(self,
|
||||
var_name,
|
||||
@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase):
|
||||
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
|
||||
sess)
|
||||
|
||||
def testWarmStart_ExplicitCheckpointFile(self):
|
||||
# Create vocab for sparse column "sc_vocab".
|
||||
vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
|
||||
"vocab")
|
||||
# Create feature column.
|
||||
sc_vocab = fc.categorical_column_with_vocabulary_file(
|
||||
"sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
|
||||
|
||||
# Save checkpoint from which to warm-start.
|
||||
_, prev_vocab_val = self._create_prev_run_var(
|
||||
"linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
|
||||
|
||||
partitioner = lambda shape, dtype: [1] * len(shape)
|
||||
# New graph, new session WITHOUT warmstarting.
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
|
||||
sess.run(variables.global_variables_initializer())
|
||||
# Without warmstarting, the weights should be initialized using default
|
||||
# initializer (which is init_ops.zeros_initializer).
|
||||
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
|
||||
sess)
|
||||
|
||||
# New graph, new session with warmstarting.
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.test_session(graph=g) as sess:
|
||||
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
|
||||
# Since old vocab is not explicitly set in WarmStartSettings, the old
|
||||
# vocab is assumed to be same as new vocab.
|
||||
ws_util._warmstart(ws_util._WarmStartSettings(
|
||||
# Explicitly provide the file prefix instead of just the dir.
|
||||
os.path.join(self.get_temp_dir(), "model-0"),
|
||||
vars_to_warmstart=".*sc_vocab.*"))
|
||||
sess.run(variables.global_variables_initializer())
|
||||
# Verify weights were correctly warmstarted.
|
||||
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
|
||||
sess)
|
||||
|
||||
def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
|
||||
# Create old vocabulary, and use a size smaller than the total number of
|
||||
# entries.
|
||||
|
Loading…
Reference in New Issue
Block a user