diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py
index 5a66f6ce8b9..d82dcc616c6 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2.py
@@ -696,17 +696,33 @@ class TPUEmbedding(tracking.AutoTrackable):
     """Create all variables."""
     shape = (table.vocabulary_size, table.dim)
 
-    # We use functools.partial here for the initial_value so that we have a
-    # variable creation that is compatible with both the sharded variable
-    # creator and the normal variable creator. The sharded variable creator
-    # will extract the shape of the tensor from the functool.partial object to
-    # decide on the sharding.
-    parameters = tf_variables.Variable(
-        name=table.name,
-        initial_value=functools.partial(
-            table.initializer, shape=shape, dtype=dtypes.float32),
-        trainable=not self._using_tpu)
-    slot_vars = table.optimizer._create_slots(parameters)  # pylint: disable=protected-access
+    def getter(name, shape, dtype, initializer, trainable):
+      return tf_variables.Variable(
+          name=name,
+          initial_value=functools.partial(initializer, shape, dtype=dtype),
+          trainable=trainable)
+
+    def variable_creator(name, initializer, trainable=True):
+      # Use _add_variable_with_custom_getter here so that we take advantage
+      # of checkpoint loading to allow restores before the variables are
+      # created, which avoids double initialization.
+      return self._add_variable_with_custom_getter(
+          name=name,
+          initializer=initializer,
+          shape=shape,
+          dtype=dtypes.float32,
+          getter=getter,
+          trainable=trainable)
+
+    parameters = variable_creator(table.name, table.initializer,
+                                  trainable=not self._using_tpu)
+
+    def slot_creator(name, initializer):
+      return variable_creator(table.name + "/" + name,
+                              initializer,
+                              False)
+
+    slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
     slot_vars["parameters"] = parameters
     return slot_vars
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_test.py b/tensorflow/python/tpu/tpu_embedding_v2_test.py
index 78b5c9fa3bc..a8b21480919 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_test.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.module import module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import init_ops_v2
@@ -56,7 +57,6 @@ from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.training.tracking import util
 from tensorflow.python.util import nest
 
-
 FLAGS = flags.FLAGS
 flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
 flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
@@ -161,6 +161,60 @@ class TPUEmbeddingCheckpointTest(parameterized.TestCase, test.TestCase):
         msg='Second mid level api should have retrieved the first model values.'
     )
 
+  def test_checkpoint_restore_before_variable_creation(self):
+
+    class TestModule(module.Module):
+
+      def __init__(self, initializer, rows):
+        self._initializer = initializer
+        self._rows = rows
+
+      def create_embedding(self):
+        table = tpu_embedding_v2_utils.TableConfig(
+            vocabulary_size=self._rows, dim=4, initializer=self._initializer,
+            combiner='sum', name='table')
+        feature_config = (tpu_embedding_v2_utils.FeatureConfig(
+            table=table, name='feature'),)
+        optimizer = tpu_embedding_v2_utils.SGD()
+
+        self.tpu_embedding = tpu_embedding_v2.TPUEmbedding(
+            feature_config, self._rows, optimizer)
+
+    # We need to clear the already loaded config provided by the setUp method.
+    tpu_strategy_util.initialize_tpu_system(self.resolver)
+
+    with self.strategy.scope():
+      module1 = TestModule(init_ops_v2.Ones(),
+                           self.strategy.num_replicas_in_sync * 2)
+      module1.create_embedding()
+
+    checkpoint = util.Checkpoint(test_module=module1)
+    checkpoint.save(_get_tmpdir('restore_before_create', 'save'))
+
+    tpu_strategy_util.initialize_tpu_system(self.resolver)
+
+    with self.strategy.scope():
+      module2 = TestModule(init_ops_v2.Zeros(),
+                           self.strategy.num_replicas_in_sync * 2)
+
+    checkpoint = util.Checkpoint(test_module=module2)
+    checkpoint.restore(_get_tmpdir('restore_before_create', 'save-1'))
+
+    with self.strategy.scope():
+      module2.create_embedding()
+
+    def get_values(mid):
+      return mid._variables['table']['parameters'].variables[0].numpy()
+
+    self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
+                        get_values(module2.tpu_embedding))
+
+    # Fetch the values from the TPU to check that they are the same.
+    module2.tpu_embedding._retrieve_variables()
+
+    self.assertAllClose(np.ones((self.strategy.num_replicas_in_sync * 2, 4)),
+                        get_values(module2.tpu_embedding))
+
   def build_mid_level(self, embedding_values, optimizer,
                       initialize_tpu_embedding=True):
     """Creates an embedding api object initialized to embedding_values."""
@@ -172,7 +226,7 @@ class TPUEmbeddingCheckpointTest(parameterized.TestCase, test.TestCase):
     feature_config = (tpu_embedding_v2_utils.FeatureConfig(
         table=table, name='feature'),)
 
-    # batch_size here does not matter as we aren't traininig in any of these
+    # batch_size here does not matter as we aren't training in any of these
     # tests.
     return tpu_embedding_v2.TPUEmbedding(
         feature_config, 64, optimizer,
diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py
index bba0d10a62f..9d7de203889 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py
@@ -20,13 +20,11 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 import abc
-import functools
 import math
 import six
 
 from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
 from tensorflow.python.ops import init_ops_v2
-from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.tpu.ops import tpu_ops
 from tensorflow.python.util.tf_export import tf_export
 
@@ -101,13 +99,13 @@ class _Optimizer(object):
     """Returns the retrieve function for the optimizer."""
     raise NotImplementedError
 
-  def _create_slots(self, table):
+  def _create_slots(self, table, variable_creator):
     """Creates slot variables for table.
 
-    Uses shape of table to create parallel slot variables.
-
     Args:
-      table: A Variable or equivalent.
+      table: The table variable to create slots for.
+      variable_creator: A function which creates variables. Takes parameters
+        'name' and 'initializer'.
 
     Returns:
      A dict of variables, keyed by self._slot_names().
    """
@@ -118,11 +116,7 @@ class _Optimizer(object):
     slots = {}
     for slot, initializer in zip(self._slot_names(),
                                  self._slot_initializers()):
-      slots[slot] = tf_variables.Variable(
-          name=table.name + "/" + slot,
-          initial_value=functools.partial(
-              initializer, shape=table.shape, dtype=table.dtype),
-          trainable=False)
+      slots[slot] = variable_creator(name=slot, initializer=initializer)
     return slots
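
Note (not part of the patch): the restore-before-create behaviour relied on here comes from TF2's object-based checkpointing, where tf.train.Checkpoint.restore defers restoration of dependencies that do not exist yet and applies it when the variable is finally created. _add_variable_with_custom_getter is the Trackable hook that lets creation pick up such a pending restore instead of running the initializer. A minimal standalone sketch of that mechanism, using a hypothetical LazyTable module (all names below are illustrative, not from this patch):

import tensorflow as tf

class LazyTable(tf.Module):
  """Creates its variable only when build() is called, like TPUEmbedding."""

  def __init__(self, initializer):
    super().__init__()
    self._initializer = initializer
    self.table = None

  def build(self):
    self.table = tf.Variable(self._initializer(shape=(4, 4), dtype=tf.float32))

m1 = LazyTable(tf.ones)
m1.build()
path = tf.train.Checkpoint(mod=m1).save('/tmp/lazy_ckpt')

m2 = LazyTable(tf.zeros)
tf.train.Checkpoint(mod=m2).restore(path)  # queued; no variable exists yet
m2.build()                # variable created; the deferred restore fires here
print(m2.table.numpy())   # ones, not zeros: the checkpoint beat the initializer

The patch routes TPUEmbedding's table and slot creation through this same hook so that variables created inside the user's scope after checkpoint.restore() pick up the checkpoint values directly, avoiding initializing on the TPU and then overwriting with restored values.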