Enable TensorCore embeddings for training via FeatureColumnV2.

PiperOrigin-RevId: 312340625
Change-Id: I559aba797a8f1a37ecec1e4ee71cd027701ae6dd
This commit is contained in:
A. Unique TensorFlower 2020-05-19 13:22:10 -07:00 committed by TensorFlower Gardener
parent b3387c0c19
commit 82143c1ad8
3 changed files with 270 additions and 94 deletions

View File

@ -31,15 +31,18 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2
from tensorflow.python.tpu.feature_column import _SUPPORTED_SEQUENCE_COLUMNS
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
# String names of the embedding lookup devices. Presumably these are the values
# accepted for the `embedding_lookup_device` argument -- TODO confirm against
# the validation code, which is not visible in this hunk.
_ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core']
# Suffix appended to a feature key / categorical column name to look up the
# mask tensor that accompanies already densified-and-padded TensorCore
# training inputs.
_TENSOR_CORE_MASK_KEY_SUFFIX = '__TENSOR_CORE_MASK'
class EmbeddingDevice(enum.Enum):
@ -174,10 +177,13 @@ def embedding_column_v2(categorical_column,
elif embedding_lookup_device == 'tpu_embedding_core':
embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE
if (embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE and
not tensor_core_shape):
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
'tensor_core_shape to be set.')
if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
if not tensor_core_shape:
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
'tensor_core_shape to be set.')
if isinstance(categorical_column, _SUPPORTED_SEQUENCE_COLUMNS):
raise ValueError('embedding_lookup_device=tpu_tensor_core currently does '
'not support sequence columns.')
if not embedding_lookup_device:
return _TPUEmbeddingColumnV2(
@ -372,10 +378,14 @@ def shared_embedding_columns_v2(categorical_columns,
elif embedding_lookup_device == 'tpu_embedding_core':
embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE
if (embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE and
not tensor_core_shape):
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
'tensor_core_shape to be set.')
if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
if not tensor_core_shape:
raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
'tensor_core_shape to be set.')
for c in sorted_columns:
if isinstance(c, _SUPPORTED_SEQUENCE_COLUMNS):
raise ValueError('embedding_lookup_device=tpu_tensor_core currently '
'does not support sequence columns.')
# Create the state (_SharedEmbeddingColumnLayer) here.
for categorical_column, max_sequence_length in zip(
@ -807,7 +817,13 @@ def sparse_embedding_aggregate_slice(params,
if combiner == 'sum':
return aggregate_emb
elif combiner == 'mean':
return aggregate_emb / math_ops.reduce_sum(values_mask_broadcast, axis=1)
# In the case we have an empty row, both aggregate_emb and
# math_ops.reduce_sum(values_mask_broadcast, axis=1) will be 0. Thus,
# we can take max it with a non-zero value to prevent NaNs. Note that
# math_ops.reduce_sum(values_mask_broadcast, axis=1) will have integer
# values so 1.0 is the smallest value.
return aggregate_emb / math_ops.maximum(
math_ops.reduce_sum(values_mask_broadcast, axis=1), 1.0)
else:
raise ValueError('Dense TPU Embedding does not support combiner '
'other than sum and mean.')
@ -851,6 +867,20 @@ def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size):
return padded_values, padded_mask
def _check_invalid_cases(embedding_lookup_device):
  """Rejects lookup devices that do not fit the current TPU context.

  Args:
    embedding_lookup_device: The `EmbeddingDevice` value to validate.

  Raises:
    ValueError: If `TPU_EMBEDDING_CORE` is requested while under a TPU
      inference context, or if `CPU` is requested while not under one.
  """
  in_inference = tpu.under_tpu_inference_context()
  if in_inference:
    # The embedding core cannot be used for lookups at inference time.
    if embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      raise ValueError(
          'Using embedding_lookup_device=tpu_embedding_core during inference '
          'is not supported.')
  elif embedding_lookup_device == EmbeddingDevice.CPU:
    # CPU lookups are only allowed under the inference context.
    raise ValueError(
        'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
        'during training is not supported.')
class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
"""TPUEmbeddingColumn which allows serving on TensorCore."""
@ -874,46 +904,105 @@ class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
del kwargs['embedding_lookup_device']
_TPUEmbeddingColumnV2.__init__(self, *args, **kwargs)
def create_state(self, state_manager):
if (tpu.under_tpu_inference_context() and
self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
raise ValueError(
'Using embedding_lookup_device=tpu_embedding_core during inference '
'is not supported.')
if self._embedding_lookup_device == EmbeddingDevice.CPU:
if tpu.under_tpu_inference_context():
return fc_lib.EmbeddingColumn.create_state(self, state_manager)
else:
raise ValueError(
'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
'during training is not supported.')
def __deepcopy__(self, memo):
  """Returns a copy of this column that keeps its device-specific settings.

  The positional constructor arguments reported by `__getnewargs__` are
  deep-copied with `memo`; `tensor_core_shape` and `embedding_lookup_device`
  are handed to the new instance as-is.

  Args:
    memo: The memo dictionary supplied by `copy.deepcopy`.

  Returns:
    A new `_TPUDeviceSpecificEmbeddingColumnV2`.
  """
  copied_args = [copy.deepcopy(arg, memo) for arg in self.__getnewargs__()]
  return _TPUDeviceSpecificEmbeddingColumnV2(
      *copied_args,
      tensor_core_shape=self._tensor_core_shape,
      embedding_lookup_device=self._embedding_lookup_device)
return super(_TPUDeviceSpecificEmbeddingColumnV2,
self).create_state(state_manager)
def create_state(self, state_manager):
  """Creates column state via the implementation matching the lookup device.

  Args:
    state_manager: Forwarded unchanged to the delegated `create_state`.

  Returns:
    Whatever the delegated `create_state` implementation returns.

  Raises:
    ValueError: Propagated from `_check_invalid_cases` when the configured
      device is not valid in the current context.
  """
  device = self._embedding_lookup_device
  _check_invalid_cases(device)

  use_cpu_path = device == EmbeddingDevice.CPU or _is_running_on_cpu()
  if not use_cpu_path and device == EmbeddingDevice.TPU_EMBEDDING_CORE:
    # TPU_EMBEDDING_CORE case: defer to the parent TPU embedding column.
    return super(_TPUDeviceSpecificEmbeddingColumnV2,
                 self).create_state(state_manager)

  # CPU case and TPU_TENSOR_CORE case: both use the plain EmbeddingColumn
  # state.
  return fc_lib.EmbeddingColumn.create_state(self, state_manager)
def get_dense_tensor(self, transformation_cache, state_manager):
"""Private method that follows get_dense_tensor."""
# If we aren't inferencing on TensorCore, just delegate to parent.
if not tpu.under_tpu_inference_context() or not self._tensor_core_shape:
_check_invalid_cases(self._embedding_lookup_device)
# CPU Case.
if self._embedding_lookup_device == EmbeddingDevice.CPU or _is_running_on_cpu(
):
return super(_TPUDeviceSpecificEmbeddingColumnV2,
self).get_dense_tensor(transformation_cache, state_manager)
# TPU_EMBEDDING_CORE case.
elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
return super(_TPUDeviceSpecificEmbeddingColumnV2,
self).get_dense_tensor(transformation_cache, state_manager)
sparse_tensor = transformation_cache.get(self.categorical_column.name,
state_manager)
# Use outside compile to densify and pad the input tensors.
def host_computation():
return pad_sparse_embedding_lookup_indices(sparse_tensor,
self._tensor_core_shape[1])
# TPU_TENSOR_CORE cases.
if tpu.under_tpu_inference_context():
# For inference, use outside compile to densify and pad the input tensors.
sparse_tensor = transformation_cache.get(self.categorical_column.name,
state_manager)
values, mask = tpu.outside_compilation(host_computation)
def host_computation():
return pad_sparse_embedding_lookup_indices(sparse_tensor,
self._tensor_core_shape[1])
# Do a dense embedding lookup on TensorCore.
embedding_weights = state_manager.get_variable(self, 'embedding_weights')
embedding = sparse_embedding_aggregate_slice(embedding_weights,
(values, mask),
self.get_combiner())
return embedding
values, mask = tpu.outside_compilation(host_computation)
else:
# For training, the inputs should already have been densified and padded.
values = transformation_cache.get(self.categorical_column.name,
state_manager)
mask = transformation_cache.get(
self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
state_manager)
embedding_weights = state_manager.get_variable(
self, name='embedding_weights')
return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
self.get_combiner())
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
  """Returns the embedding for this column, dispatching on the lookup device.

  CPU and TPU_EMBEDDING_CORE lookups are delegated to the parent class. The
  remaining TPU_TENSOR_CORE lookup creates the `embedding_weights` variable
  here and combines rows with `sparse_embedding_aggregate_slice`.

  Args:
    inputs: Object whose `get(key)` returns the tensor stored under a feature
      key.
    weight_collections: Optional list of collection names for the embedding
      variable. When non-empty and lacking `GLOBAL_VARIABLES`, that key is
      appended to the list in place.
    trainable: Combined with `self.trainable` via `and` to decide whether the
      embedding variable is trainable.

  Returns:
    The parent class result for the CPU / TPU_EMBEDDING_CORE cases, otherwise
    the result of `sparse_embedding_aggregate_slice`.

  Raises:
    ValueError: Propagated from `_check_invalid_cases` when the configured
      device is not valid in the current context.
  """
  device = self._embedding_lookup_device
  _check_invalid_cases(device)

  # CPU case and TPU_EMBEDDING_CORE case: both defer to the parent class.
  if (device == EmbeddingDevice.CPU or _is_running_on_cpu() or
      device == EmbeddingDevice.TPU_EMBEDDING_CORE):
    return super(_TPUDeviceSpecificEmbeddingColumnV2,
                 self)._get_dense_tensor(inputs, weight_collections,
                                         trainable)

  # TPU_TENSOR_CORE cases.
  if not tpu.under_tpu_inference_context():
    # For training, the inputs should already have been densified and padded;
    # the mask is stored under the feature key plus the mask suffix.
    values = inputs.get(self.get_feature_key_name())
    mask = inputs.get(self.get_feature_key_name() +
                      _TENSOR_CORE_MASK_KEY_SUFFIX)
  else:
    # For inference, use outside compile to densify and pad the input tensors.
    sparse_ids = inputs.get(self.get_feature_key_name())

    def _pad_on_host():
      return pad_sparse_embedding_lookup_indices(sparse_ids,
                                                 self._tensor_core_shape[1])

    values, mask = tpu.outside_compilation(_pad_on_host)

  if (weight_collections and
      ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
    weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)

  num_buckets = self.categorical_column._num_buckets  # pylint: disable=protected-access
  table = variable_scope.get_variable(
      name='embedding_weights',
      shape=(num_buckets, self.dimension),
      dtype=dtypes.float32,
      initializer=self.initializer,
      trainable=self.trainable and trainable,
      collections=weight_collections)
  return sparse_embedding_aggregate_slice(table, (values, mask),
                                          self.get_combiner())
class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
@ -940,34 +1029,47 @@ class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
del kwargs['embedding_lookup_device']
_TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs)
def __deepcopy__(self, memo):
  """Returns a copy of this shared column that keeps its device settings.

  The positional constructor arguments reported by `__getnewargs__` are
  deep-copied with `memo`; `tensor_core_shape` and `embedding_lookup_device`
  are handed to the new instance as-is.

  Args:
    memo: The memo dictionary supplied by `copy.deepcopy`.

  Returns:
    A new `_TPUSharedDeviceSpecificEmbeddingColumnV2`.
  """
  copied_args = [copy.deepcopy(arg, memo) for arg in self.__getnewargs__()]
  return _TPUSharedDeviceSpecificEmbeddingColumnV2(
      *copied_args,
      tensor_core_shape=self._tensor_core_shape,
      embedding_lookup_device=self._embedding_lookup_device)
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
"""Private method that follows _get_dense_tensor_internal."""
if (tpu.under_tpu_inference_context() and
self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
raise ValueError('Using embedding_lookup_device=tpu_embedding_core '
'during inference is not supported.')
if self._embedding_lookup_device == EmbeddingDevice.CPU:
if tpu.under_tpu_inference_context():
return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
self)._get_dense_tensor_internal(transformation_cache,
state_manager)
else:
raise ValueError(
'Using TPUSharedEmbeddingColumn with '
'embedding_lookup_device="cpu" during training is not supported.')
sparse_tensor = transformation_cache.get(self.categorical_column.name,
state_manager)
_check_invalid_cases(self._embedding_lookup_device)
# CPU Case.
if self._embedding_lookup_device == EmbeddingDevice.CPU or _is_running_on_cpu(
):
return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
self)._get_dense_tensor_internal(transformation_cache,
state_manager)
# TPU_EMBEDDING_CORE case.
if self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
self)._get_dense_tensor_internal(transformation_cache,
state_manager)
# Use outside compile to densify and pad the input tensors.
def host_computation():
return pad_sparse_embedding_lookup_indices(sparse_tensor,
self._tensor_core_shape[1])
# TPU_TENSOR_CORE cases.
if tpu.under_tpu_inference_context():
# For inference, use outside compile to densify and pad the input tensors.
sparse_tensor = transformation_cache.get(self.categorical_column.name,
state_manager)
values, mask = tpu.outside_compilation(host_computation)
def host_computation():
return pad_sparse_embedding_lookup_indices(sparse_tensor,
self._tensor_core_shape[1])
values, mask = tpu.outside_compilation(host_computation)
else:
# For training, the inputs should already have been densified and padded.
values = transformation_cache.get(self.categorical_column.name,
state_manager)
mask = transformation_cache.get(
self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
state_manager)
# Do a dense embedding lookup on TensorCore.
embedding_weights = self.shared_embedding_column_creator.embedding_weights
embedding = sparse_embedding_aggregate_slice(embedding_weights,
(values, mask),
self.get_combiner())
return embedding
return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
self.get_combiner())

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.tpu import feature_column_v2 as tpu_fc
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_function
def _initialized_session():
@ -514,50 +515,119 @@ class DeviceSpecificEmbeddingColumnTestV2(test.TestCase,
embedding_lookup_device='tpu_tensor_core',
tensor_core_shape=[None, 3])
# Run in TPUInferenceContext so that we hit the intended densification case.
# Run in TPUContexts so that we hit the intended densification case.
context = tpu._TPUInferenceContext('tpu_inference')
context.Enter()
with tpu_function.tpu_shard_context(1):
dense_features = fc_lib.DenseFeatures(embedding_column)
# Sqrtn combiner not supported for now.
if combiner == 'sqrtn':
with self.assertRaisesRegexp(
ValueError, 'Dense TPU Embedding does not support combiner'):
embedding_lookup = dense_features(input_features)
return
if combiner == 'mean':
expected_lookups = (
# example 0:
(7., 11.), # ids [2], embedding = [7, 11]
# example 1:
(2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) =
# [2, 3.5]
)
elif combiner == 'sum':
expected_lookups = (
# example 0:
(7., 11.), # ids [2], embedding = [7, 11]
# example 1:
(4., 7), # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7]
)
dense_features = fc_lib.DenseFeatures(embedding_column)
# Sqrtn combiner not supported for now.
if combiner == 'sqrtn':
with self.assertRaisesRegexp(
ValueError, 'Dense TPU Embedding does not support combiner'):
embedding_lookup = dense_features(input_features)
return
if combiner == 'mean':
embedding_lookup = dense_features(input_features)
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
if shared:
self.assertCountEqual(('inp_shared_embedding:0',),
tuple([v.name for v in global_vars]))
else:
self.assertCountEqual(
('dense_features/inp_embedding/embedding_weights:0',),
tuple([v.name for v in global_vars]))
embedding_var = global_vars[0]
with _initialized_session():
self.assertAllEqual(embedding_values, embedding_var.eval())
eval_res = embedding_lookup.eval()
self.assertAllEqual(expected_lookups, eval_res)
context.Exit()
@test_util.deprecated_graph_mode_only
def test_empty_row(self):
# Inputs.
vocabulary_size = 3
input_sparse_tensor = sparse_tensor.SparseTensorValue(
# example 0, ids []
# example 1, ids [0, 1, 3]
indices=((1, 0), (1, 1), (1, 4)),
values=(0, 1, 3),
dense_shape=(2, 5))
input_features = {'inp': input_sparse_tensor}
# Embedding variable.
embedding_dimension = 2
embedding_values = (
(1., 2.), # id 0
(3., 5.), # id 1
(7., 11.), # id 2
(13., 17.) # id 3
)
def _initializer(shape, dtype, partition_info=None):
self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
self.assertEqual(dtypes.float32, dtype)
self.assertIsNone(partition_info)
return embedding_values
# Build columns.
categorical_column_input = fc_lib.categorical_column_with_identity(
key='inp', num_buckets=vocabulary_size)
# Set tensor_core_shape to be [None, 3] to ensure some padding and
# dynamic batch size.
embedding_column = tpu_fc.embedding_column_v2(
categorical_column_input,
dimension=embedding_dimension,
initializer=_initializer,
combiner='mean',
embedding_lookup_device='tpu_tensor_core',
tensor_core_shape=[None, 3])
# Run in TPUContexts so that we hit the intended densification case.
context = tpu._TPUInferenceContext('tpu_inference')
context.Enter()
with tpu_function.tpu_shard_context(1):
dense_features = fc_lib.DenseFeatures(embedding_column)
expected_lookups = (
# example 0:
(7., 11.), # ids [2], embedding = [7, 11]
(0., 0.), # ids [], embedding = [0, 0]
# example 1:
(2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
)
elif combiner == 'sum':
expected_lookups = (
# example 0:
(7., 11.), # ids [2], embedding = [7, 11]
# example 1:
(4., 7), # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7]
)
embedding_lookup = dense_features(input_features)
embedding_lookup = dense_features(input_features)
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
if shared:
self.assertCountEqual(('inp_shared_embedding:0',),
tuple([v.name for v in global_vars]))
else:
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertCountEqual(
('dense_features/inp_embedding/embedding_weights:0',),
tuple([v.name for v in global_vars]))
embedding_var = global_vars[0]
with _initialized_session():
self.assertAllEqual(embedding_values, embedding_var.eval())
eval_res = embedding_lookup.eval()
self.assertAllEqual(expected_lookups, eval_res)
context.Exit()
embedding_var = global_vars[0]
with _initialized_session():
self.assertAllEqual(embedding_values, embedding_var.eval())
eval_res = embedding_lookup.eval()
self.assertAllEqual(expected_lookups, eval_res)
context.Exit()
@test_util.deprecated_graph_mode_only
def test_error_dense_shape_invalid(self):

View File

@ -35,6 +35,10 @@ tf_class {
name: "table_to_config_dict"
mtype: "<type \'property\'>"
}
member {
name: "tensor_core_feature_columns"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}