Allow checkpoint loading before mid level TPU embedding API creation. This should avoid double initialization.
PiperOrigin-RevId: 314770483 Change-Id: I33456f712ec197ac03958bbba21b62b0a4cc0be5
This commit is contained in:
parent
48678a1e2d
commit
1af42f028b
|
@ -696,17 +696,33 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||
"""Create all variables."""
|
||||
shape = (table.vocabulary_size, table.dim)
|
||||
|
||||
# We use functools.partial here for the initial_value so that we have a
|
||||
# variable creation that is compatible with both the sharded variable
|
||||
# creator and the normal variable creator. The sharded variable creator
|
||||
# will extract the shape of the tensor from the functool.partial object to
|
||||
# decide on the sharding.
|
||||
parameters = tf_variables.Variable(
|
||||
name=table.name,
|
||||
initial_value=functools.partial(
|
||||
table.initializer, shape=shape, dtype=dtypes.float32),
|
||||
def getter(name, shape, dtype, initializer, trainable):
|
||||
return tf_variables.Variable(
|
||||
name=name,
|
||||
initial_value=functools.partial(initializer, shape, dtype=dtype),
|
||||
trainable=trainable)
|
||||
|
||||
def variable_creator(name, initializer, trainable=True):
|
||||
# use add_variable_with_custom_getter here so that we take advantage of
|
||||
# the checkpoint loading to allow restore before the variables get
|
||||
# created which avoids double initialization.
|
||||
return self._add_variable_with_custom_getter(
|
||||
name=name,
|
||||
initializer=initializer,
|
||||
shape=shape,
|
||||
dtype=dtypes.float32,
|
||||
getter=getter,
|
||||
trainable=trainable)
|
||||
|
||||
parameters = variable_creator(table.name, table.initializer,
|
||||
trainable=not self._using_tpu)
|
||||
slot_vars = table.optimizer._create_slots(parameters) # pylint: disable=protected-access
|
||||
|
||||
def slot_creator(name, initializer):
|
||||
return variable_creator(table.name + "/" + name,
|
||||
initializer,
|
||||
False)
|
||||
|
||||
slot_vars = table.optimizer._create_slots(parameters, slot_creator) # pylint: disable=protected-access
|
||||
slot_vars["parameters"] = parameters
|
||||
return slot_vars
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ from tensorflow.python.framework import dtypes
|
|||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
|
@ -56,7 +57,6 @@ from tensorflow.python.training import checkpoint_utils
|
|||
from tensorflow.python.training.tracking import util
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
|
||||
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
|
||||
|
@ -161,6 +161,60 @@ class TPUEmbeddingCheckpointTest(parameterized.TestCase, test.TestCase):
|
|||
msg='Second mid level api should have retrieved the first model values.'
|
||||
)
|
||||
|
||||
def test_checkpoint_restore_before_variable_creation(self):
|
||||
|
||||
class TestModule(module.Module):
|
||||
|
||||
def __init__(self, initializer, rows):
|
||||
self._initializer = initializer
|
||||
self._rows = rows
|
||||
|
||||
def create_embedding(self):
|
||||
table = tpu_embedding_v2_utils.TableConfig(
|
||||
vocabulary_size=self._rows, dim=4, initializer=self._initializer,
|
||||
combiner='sum', name='table')
|
||||
feature_config = (tpu_embedding_v2_utils.FeatureConfig(
|
||||
table=table, name='feature'),)
|
||||
optimizer = tpu_embedding_v2_utils.SGD()
|
||||
|
||||
self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
|
||||
feature_config, self._rows, optimizer)
|
||||
|
||||
# We need to clear the already loaded config provided by setUp method.
|
||||
tpu_strategy_util.initialize_tpu_system(self.resolver)
|
||||
|
||||
with self.strategy.scope():
|
||||
module1 = TestModule(init_ops_v2.Ones(),
|
||||
self.strategy.num_replicas_in_sync * 2)
|
||||
module1.create_embedding()
|
||||
|
||||
checkpoint = util.Checkpoint(test_module=module1)
|
||||
checkpoint.save(_get_tmpdir('restore_before_create', 'save'))
|
||||
|
||||
tpu_strategy_util.initialize_tpu_system(self.resolver)
|
||||
|
||||
with self.strategy.scope():
|
||||
module2 = TestModule(init_ops_v2.Zeros(),
|
||||
self.strategy.num_replicas_in_sync * 2)
|
||||
|
||||
checkpoint = util.Checkpoint(test_module=module2)
|
||||
checkpoint.restore(_get_tmpdir('restore_before_create', 'save-1'))
|
||||
|
||||
with self.strategy.scope():
|
||||
module2.create_embedding()
|
||||
|
||||
def get_values(mid):
|
||||
return mid._variables['table']['parameters'].variables[0].numpy()
|
||||
|
||||
self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
|
||||
get_values(module2.tpu_embedding))
|
||||
|
||||
# Fetch the values from the TPU to check that they are the same.
|
||||
module2.tpu_embedding._retrieve_variables()
|
||||
|
||||
self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
|
||||
get_values(module2.tpu_embedding))
|
||||
|
||||
def build_mid_level(self, embedding_values, optimizer,
|
||||
initialize_tpu_embedding=True):
|
||||
"""Creates an embedding api object initialized to embedding_values."""
|
||||
|
@ -172,7 +226,7 @@ class TPUEmbeddingCheckpointTest(parameterized.TestCase, test.TestCase):
|
|||
feature_config = (tpu_embedding_v2_utils.FeatureConfig(
|
||||
table=table, name='feature'),)
|
||||
|
||||
# batch_size here does not matter as we aren't traininig in any of these
|
||||
# batch_size here does not matter as we aren't training in any of these
|
||||
# tests.
|
||||
return tpu_embedding_v2.TPUEmbedding(
|
||||
feature_config, 64, optimizer,
|
||||
|
|
|
@ -20,13 +20,11 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import abc
|
||||
import functools
|
||||
import math
|
||||
import six
|
||||
|
||||
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
|
||||
from tensorflow.python.ops import init_ops_v2
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.tpu.ops import tpu_ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
@ -101,13 +99,13 @@ class _Optimizer(object):
|
|||
"""Returns the retrieve function for the optimizer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _create_slots(self, table):
|
||||
def _create_slots(self, table, variable_creator):
|
||||
"""Creates slot variables for table.
|
||||
|
||||
Uses shape of table to create parallel slot variables.
|
||||
|
||||
Args:
|
||||
table: A Variable or equivalent.
|
||||
table: The table variable to create slots for.
|
||||
variable_creator: A function which creates variables. Takes parameters
|
||||
'name', 'initializer'.
|
||||
|
||||
Returns:
|
||||
A dict of variables, keyed by self._slot_names().
|
||||
|
@ -118,11 +116,7 @@ class _Optimizer(object):
|
|||
slots = {}
|
||||
for slot, initializer in zip(self._slot_names(),
|
||||
self._slot_initializers()):
|
||||
slots[slot] = tf_variables.Variable(
|
||||
name=table.name + "/" + slot,
|
||||
initial_value=functools.partial(
|
||||
initializer, shape=table.shape, dtype=table.dtype),
|
||||
trainable=False)
|
||||
slots[slot] = variable_creator(name=slot, initializer=initializer)
|
||||
return slots
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue