Fix _warmstart_var_with_vocab to truly accept either a checkpoint dir or a direct path to checkpoint files.
PiperOrigin-RevId: 181341437
This commit is contained in:
parent
9fcaafd4a9
commit
611f18f179
@ -121,7 +121,10 @@ class _WarmStartSettings(
|
|||||||
# where ws could be defined as:
|
# where ws could be defined as:
|
||||||
|
|
||||||
# Warm-start all weights in the model (input layer and hidden weights).
|
# Warm-start all weights in the model (input layer and hidden weights).
|
||||||
|
# Either the directory or a specific checkpoint can be provided (in the case
|
||||||
|
# of the former, the latest checkpoint will be used).
|
||||||
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp")
|
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp")
|
||||||
|
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
|
||||||
|
|
||||||
# Warm-start only the embeddings (input layer).
|
# Warm-start only the embeddings (input layer).
|
||||||
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp",
|
ws = _WarmStartSettings(ckpt_to_initialize_from="/tmp",
|
||||||
@ -348,7 +351,7 @@ def _warmstart_var_with_vocab(var,
|
|||||||
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
|
# TODO(vihanjain): Support _WarmstartSettings where class vocabularies need
|
||||||
# remapping too.
|
# remapping too.
|
||||||
init = checkpoint_ops._load_and_remap_matrix_initializer(
|
init = checkpoint_ops._load_and_remap_matrix_initializer(
|
||||||
ckpt_path=saver.latest_checkpoint(prev_ckpt),
|
ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
|
||||||
old_tensor_name=prev_tensor_name,
|
old_tensor_name=prev_tensor_name,
|
||||||
new_row_vocab_size=current_vocab_size,
|
new_row_vocab_size=current_vocab_size,
|
||||||
new_col_vocab_size=v_shape[1],
|
new_col_vocab_size=v_shape[1],
|
||||||
|
@ -50,9 +50,7 @@ class WarmStartingUtilTest(test.TestCase):
|
|||||||
sess.run(variables.global_variables_initializer())
|
sess.run(variables.global_variables_initializer())
|
||||||
saver = saver_lib.Saver()
|
saver = saver_lib.Saver()
|
||||||
ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
|
ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
|
||||||
ckpt_state_name = "checkpoint"
|
saver.save(sess, ckpt_prefix, global_step=0)
|
||||||
saver.save(
|
|
||||||
sess, ckpt_prefix, global_step=0, latest_filename=ckpt_state_name)
|
|
||||||
|
|
||||||
def _create_prev_run_var(self,
|
def _create_prev_run_var(self,
|
||||||
var_name,
|
var_name,
|
||||||
@ -408,6 +406,44 @@ class WarmStartingUtilTest(test.TestCase):
|
|||||||
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
|
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
|
||||||
sess)
|
sess)
|
||||||
|
|
||||||
|
def testWarmStart_ExplicitCheckpointFile(self):
|
||||||
|
# Create vocab for sparse column "sc_vocab".
|
||||||
|
vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
|
||||||
|
"vocab")
|
||||||
|
# Create feature column.
|
||||||
|
sc_vocab = fc.categorical_column_with_vocabulary_file(
|
||||||
|
"sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
|
||||||
|
|
||||||
|
# Save checkpoint from which to warm-start.
|
||||||
|
_, prev_vocab_val = self._create_prev_run_var(
|
||||||
|
"linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
|
||||||
|
|
||||||
|
partitioner = lambda shape, dtype: [1] * len(shape)
|
||||||
|
# New graph, new session WITHOUT warmstarting.
|
||||||
|
with ops.Graph().as_default() as g:
|
||||||
|
with self.test_session(graph=g) as sess:
|
||||||
|
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
# Without warmstarting, the weights should be initialized using default
|
||||||
|
# initializer (which is init_ops.zeros_initializer).
|
||||||
|
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
|
||||||
|
sess)
|
||||||
|
|
||||||
|
# New graph, new session with warmstarting.
|
||||||
|
with ops.Graph().as_default() as g:
|
||||||
|
with self.test_session(graph=g) as sess:
|
||||||
|
cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
|
||||||
|
# Since old vocab is not explicitly set in WarmStartSettings, the old
|
||||||
|
# vocab is assumed to be same as new vocab.
|
||||||
|
ws_util._warmstart(ws_util._WarmStartSettings(
|
||||||
|
# Explicitly provide the file prefix instead of just the dir.
|
||||||
|
os.path.join(self.get_temp_dir(), "model-0"),
|
||||||
|
vars_to_warmstart=".*sc_vocab.*"))
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
# Verify weights were correctly warmstarted.
|
||||||
|
self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
|
||||||
|
sess)
|
||||||
|
|
||||||
def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
|
def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
|
||||||
# Create old vocabulary, and use a size smaller than the total number of
|
# Create old vocabulary, and use a size smaller than the total number of
|
||||||
# entries.
|
# entries.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user