Allow checkpoint loading before mid level TPU embedding API creation. This should avoid double initialization.

PiperOrigin-RevId: 314770483
Change-Id: I33456f712ec197ac03958bbba21b62b0a4cc0be5
This commit is contained in:
Bruce Fontaine 2020-06-04 11:35:57 -07:00 committed by TensorFlower Gardener
parent 48678a1e2d
commit 1af42f028b
3 changed files with 88 additions and 24 deletions

View File

@ -696,17 +696,33 @@ class TPUEmbedding(tracking.AutoTrackable):
"""Create all variables."""
shape = (table.vocabulary_size, table.dim)
# We use functools.partial here for the initial_value so that we have a
# variable creation that is compatible with both the sharded variable
# creator and the normal variable creator. The sharded variable creator
# will extract the shape of the tensor from the functool.partial object to
# decide on the sharding.
parameters = tf_variables.Variable(
name=table.name,
initial_value=functools.partial(
table.initializer, shape=shape, dtype=dtypes.float32),
trainable=not self._using_tpu)
slot_vars = table.optimizer._create_slots(parameters) # pylint: disable=protected-access
def getter(name, shape, dtype, initializer, trainable):
return tf_variables.Variable(
name=name,
initial_value=functools.partial(initializer, shape, dtype=dtype),
trainable=trainable)
def variable_creator(name, initializer, trainable=True):
# use add_variable_with_custom_getter here so that we take advantage of
# the checkpoint loading to allow restore before the variables get
# created which avoids double initialization.
return self._add_variable_with_custom_getter(
name=name,
initializer=initializer,
shape=shape,
dtype=dtypes.float32,
getter=getter,
trainable=trainable)
parameters = variable_creator(table.name, table.initializer,
trainable=not self._using_tpu)
def slot_creator(name, initializer):
return variable_creator(table.name + "/" + name,
initializer,
False)
slot_vars = table.optimizer._create_slots(parameters, slot_creator) # pylint: disable=protected-access
slot_vars["parameters"] = parameters
return slot_vars

View File

@ -39,6 +39,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops_v2
@ -56,7 +57,6 @@ from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
@ -161,6 +161,60 @@ class TPUEmbeddingCheckpointTest(parameterized.TestCase, test.TestCase):
msg='Second mid level api should have retrieved the first model values.'
)
def test_checkpoint_restore_before_variable_creation(self):
class TestModule(module.Module):
def __init__(self, initializer, rows):
self._initializer = initializer
self._rows = rows
def create_embedding(self):
table = tpu_embedding_v2_utils.TableConfig(
vocabulary_size=self._rows, dim=4, initializer=self._initializer,
combiner='sum', name='table')
feature_config = (tpu_embedding_v2_utils.FeatureConfig(
table=table, name='feature'),)
optimizer = tpu_embedding_v2_utils.SGD()
self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
feature_config, self._rows, optimizer)
# We need to clear the already loaded config provided by setUp method.
tpu_strategy_util.initialize_tpu_system(self.resolver)
with self.strategy.scope():
module1 = TestModule(init_ops_v2.Ones(),
self.strategy.num_replicas_in_sync * 2)
module1.create_embedding()
checkpoint = util.Checkpoint(test_module=module1)
checkpoint.save(_get_tmpdir('restore_before_create', 'save'))
tpu_strategy_util.initialize_tpu_system(self.resolver)
with self.strategy.scope():
module2 = TestModule(init_ops_v2.Zeros(),
self.strategy.num_replicas_in_sync * 2)
checkpoint = util.Checkpoint(test_module=module2)
checkpoint.restore(_get_tmpdir('restore_before_create', 'save-1'))
with self.strategy.scope():
module2.create_embedding()
def get_values(mid):
return mid._variables['table']['parameters'].variables[0].numpy()
self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
get_values(module2.tpu_embedding))
# Fetch the values from the TPU to check that they are the same.
module2.tpu_embedding._retrieve_variables()
self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
get_values(module2.tpu_embedding))
def build_mid_level(self, embedding_values, optimizer,
initialize_tpu_embedding=True):
"""Creates an embedding api object initialized to embedding_values."""
@ -172,7 +226,7 @@ class TPUEmbeddingCheckpointTest(parameterized.TestCase, test.TestCase):
feature_config = (tpu_embedding_v2_utils.FeatureConfig(
table=table, name='feature'),)
# batch_size here does not matter as we aren't traininig in any of these
# batch_size here does not matter as we aren't training in any of these
# tests.
return tpu_embedding_v2.TPUEmbedding(
feature_config, 64, optimizer,

View File

@ -20,13 +20,11 @@ from __future__ import print_function
from __future__ import unicode_literals
import abc
import functools
import math
import six
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util.tf_export import tf_export
@ -101,13 +99,13 @@ class _Optimizer(object):
"""Returns the retrieve function for the optimizer."""
raise NotImplementedError
def _create_slots(self, table):
def _create_slots(self, table, variable_creator):
"""Creates slot variables for table.
Uses shape of table to create parallel slot variables.
Args:
table: A Variable or equivalent.
table: The table variable to create slots for.
variable_creator: A function which creates variables. Takes parameters
'name', 'initializer'.
Returns:
A dict of variables, keyed by self._slot_names().
@ -118,11 +116,7 @@ class _Optimizer(object):
slots = {}
for slot, initializer in zip(self._slot_names(),
self._slot_initializers()):
slots[slot] = tf_variables.Variable(
name=table.name + "/" + slot,
initial_value=functools.partial(
initializer, shape=table.shape, dtype=table.dtype),
trainable=False)
slots[slot] = variable_creator(name=slot, initializer=initializer)
return slots