Add support for tensorcore serving path for feature columns.

PiperOrigin-RevId: 272716932
This commit is contained in:
A. Unique TensorFlower 2019-10-03 12:31:22 -07:00 committed by TensorFlower Gardener
parent e3e842b6c5
commit fbdc707d14
3 changed files with 551 additions and 30 deletions

View File

@ -19,10 +19,17 @@ from __future__ import print_function
import math
import enum
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
@ -31,6 +38,14 @@ from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
_ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core']
class EmbeddingDevice(enum.Enum):
  """Device on which an embedding lookup is placed.

  The members correspond, in order, to the user-facing strings listed in
  `_ALLOWED_DEVICES`: 'cpu', 'tpu_tensor_core' and 'tpu_embedding_core'.
  """
  CPU = 1
  TPU_TENSOR_CORE = 2
  TPU_EMBEDDING_CORE = 3
@tf_export(v1=['tpu.experimental.embedding_column'])
def embedding_column_v2(categorical_column,
@ -38,7 +53,9 @@ def embedding_column_v2(categorical_column,
combiner='mean',
initializer=None,
max_sequence_length=0,
learning_rate_fn=None):
learning_rate_fn=None,
embedding_lookup_device=None,
tensor_core_shape=None):
"""TPU version of `tf.compat.v1.feature_column.embedding_column`.
Note that the interface for `tf.tpu.experimental.embedding_column` is
@ -89,6 +106,21 @@ def embedding_column_v2(categorical_column,
sequence features and 0 for non-sequence features.
learning_rate_fn: A function that takes global step and returns learning
rate for the embedding table.
embedding_lookup_device: The device on which to run the embedding lookup.
Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core".
If specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
If not specified, the default behavior is embedding lookup on
"tpu_embedding_core" for training and "cpu" for inference.
Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"]
Valid options for serving : ["cpu", "tpu_tensor_core"]
For training, tpu_embedding_core is good for large embedding vocab (>1M),
otherwise, tpu_tensor_core is often sufficient.
For serving, doing embedding lookup on tpu_tensor_core during serving is
a way to reduce host cpu usage in cases where that is a bottleneck.
tensor_core_shape: If supplied, a list of integers which specifies
the intended dense shape to run embedding lookup for this feature on
TensorCore. The batch dimension can be left None or -1 to indicate
a dynamic shape. Only rank 2 shapes currently supported.
Returns:
A `_TPUEmbeddingColumnV2`.
@ -106,6 +138,9 @@ def embedding_column_v2(categorical_column,
]), type(categorical_column)))
if (dimension is None) or (dimension < 1):
raise ValueError('Invalid dimension {}.'.format(dimension))
if tensor_core_shape and len(tensor_core_shape) != 2:
raise ValueError(
'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified. '
@ -115,14 +150,41 @@ def embedding_column_v2(categorical_column,
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=1 / math.sqrt(dimension))
column = _TPUEmbeddingColumnV2(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
initializer=initializer,
max_sequence_length=max_sequence_length,
learning_rate_fn=learning_rate_fn)
return column
# Reject unknown device names. The message is formatted into one string;
# passing the list as a second ValueError argument would render the error
# as a tuple.
if (embedding_lookup_device and
    embedding_lookup_device not in _ALLOWED_DEVICES):
  raise ValueError(
      'If set, embedding_lookup_device must be in {}. Got {!r}.'.format(
          _ALLOWED_DEVICES, embedding_lookup_device))

# Translate the user-facing string into the internal enum.
if embedding_lookup_device == 'cpu':
  embedding_lookup_device = EmbeddingDevice.CPU
elif embedding_lookup_device == 'tpu_tensor_core':
  embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
elif embedding_lookup_device == 'tpu_embedding_core':
  embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

# A TensorCore lookup needs a static padded width, taken from
# tensor_core_shape.
if (embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE and
    not tensor_core_shape):
  raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                   'tensor_core_shape to be set.')
if not embedding_lookup_device:
return _TPUEmbeddingColumnV2(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
initializer=initializer,
max_sequence_length=max_sequence_length,
learning_rate_fn=learning_rate_fn)
else:
return _TPUDeviceSpecificEmbeddingColumnV2(
categorical_column=categorical_column,
dimension=dimension,
combiner=combiner,
initializer=initializer,
max_sequence_length=max_sequence_length,
learning_rate_fn=learning_rate_fn,
embedding_lookup_device=embedding_lookup_device,
tensor_core_shape=tensor_core_shape)
@tf_export(v1=['tpu.experimental.shared_embedding_columns'])
@ -132,7 +194,9 @@ def shared_embedding_columns_v2(categorical_columns,
initializer=None,
shared_embedding_collection_name=None,
max_sequence_lengths=None,
learning_rate_fn=None):
learning_rate_fn=None,
embedding_lookup_device=None,
tensor_core_shape=None):
"""TPU version of `tf.compat.v1.feature_column.shared_embedding_columns`.
Note that the interface for `tf.tpu.experimental.shared_embedding_columns` is
@ -165,15 +229,15 @@ def shared_embedding_columns_v2(categorical_columns,
Args:
categorical_columns: A list of categorical columns returned from
`categorical_column_with_identity`, `weighted_categorical_column`,
`categorical_column_with_vocabulary_file`,
`categorical_column_with_vocabulary_list`,
`sequence_categorical_column_with_identity`,
`sequence_categorical_column_with_vocabulary_file`,
`sequence_categorical_column_with_vocabulary_list`
`categorical_column_with_identity`, `weighted_categorical_column`,
`categorical_column_with_vocabulary_file`,
`categorical_column_with_vocabulary_list`,
`sequence_categorical_column_with_identity`,
`sequence_categorical_column_with_vocabulary_file`,
`sequence_categorical_column_with_vocabulary_list`
dimension: An integer specifying dimension of the embedding, must be > 0.
combiner: A string specifying how to reduce if there are multiple entries
in a single row for a non-sequence column. For more information, see
combiner: A string specifying how to reduce if there are multiple entries in
a single row for a non-sequence column. For more information, see
`tf.feature_column.embedding_column`.
initializer: A variable initializer function to be used in embedding
variable initialization. If not specified, defaults to
@ -183,14 +247,29 @@ def shared_embedding_columns_v2(categorical_columns,
shared embedding weights are added. If not given, a reasonable name will
be chosen based on the names of `categorical_columns`. This is also used
in `variable_scope` when creating shared embedding weights.
max_sequence_lengths: An list of non-negative integers, either None or
empty or the same length as the argument categorical_columns. Entries
max_sequence_lengths: A list of non-negative integers, either None or empty
or the same length as the argument categorical_columns. Entries
corresponding to non-sequence columns must be 0 and entries corresponding
to sequence columns specify the max sequence length for the column. Any
sequence shorter than this will be padded with 0 embeddings and any
sequence longer will be truncated.
learning_rate_fn: A function that takes global step and returns learning
rate for the embedding table.
embedding_lookup_device: The device on which to run the embedding lookup.
Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If
specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
If not specified, the default behavior is embedding lookup on
"tpu_embedding_core" for training and "cpu" for inference.
Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"]
Valid options for serving : ["cpu", "tpu_tensor_core"]
For training, tpu_embedding_core is good for large embedding vocab (>1M),
otherwise, tpu_tensor_core is often sufficient.
For serving, doing embedding lookup on tpu_tensor_core during serving is
a way to reduce host cpu usage in cases where that is a bottleneck.
tensor_core_shape: If supplied, a list of integers which specifies the
intended dense shape to run embedding lookup for this feature on
TensorCore. The batch dimension can be left None or -1 to indicate a
dynamic shape. Only rank 2 shapes currently supported.
Returns:
A list of `_TPUSharedEmbeddingColumnV2`.
@ -222,6 +301,9 @@ def shared_embedding_columns_v2(categorical_columns,
if (dimension is None) or (dimension < 1):
raise ValueError('Invalid dimension {}.'.format(dimension))
if tensor_core_shape and len(tensor_core_shape) != 2:
raise ValueError(
'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))
if (initializer is not None) and (not callable(initializer)):
raise ValueError('initializer must be callable if specified. ')
@ -253,17 +335,46 @@ def shared_embedding_columns_v2(categorical_columns,
tensor_name_in_ckpt=None, num_buckets=num_buckets, trainable=None,
name=shared_embedding_collection_name)
# Reject unknown device names. The message is formatted into one string;
# passing the list as a second ValueError argument would render the error
# as a tuple.
if (embedding_lookup_device and
    embedding_lookup_device not in _ALLOWED_DEVICES):
  raise ValueError(
      'If set, embedding_lookup_device must be in {}. Got {!r}.'.format(
          _ALLOWED_DEVICES, embedding_lookup_device))

# Translate the user-facing string into the internal enum.
if embedding_lookup_device == 'cpu':
  embedding_lookup_device = EmbeddingDevice.CPU
elif embedding_lookup_device == 'tpu_tensor_core':
  embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
elif embedding_lookup_device == 'tpu_embedding_core':
  embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

# tensor_core_shape is required for the TensorCore device (as the error
# message states), not for the embedding-core device.
if (embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE and
    not tensor_core_shape):
  raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                   'tensor_core_shape to be set.')
# Create the state (_SharedEmbeddingColumnLayer) here.
for categorical_column, max_sequence_length in zip(
categorical_columns, max_sequence_lengths):
column = _TPUSharedEmbeddingColumnV2(
categorical_column=categorical_column,
shared_embedding_column_creator=column_creator,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
max_sequence_length=max_sequence_length,
learning_rate_fn=learning_rate_fn)
if not embedding_lookup_device:
column = _TPUSharedEmbeddingColumnV2(
categorical_column=categorical_column,
shared_embedding_column_creator=column_creator,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
max_sequence_length=max_sequence_length,
learning_rate_fn=learning_rate_fn)
else:
column = _TPUSharedDeviceSpecificEmbeddingColumnV2(
categorical_column=categorical_column,
shared_embedding_column_creator=column_creator,
combiner=combiner,
initializer=initializer,
shared_embedding_collection_name=shared_embedding_collection_name,
max_sequence_length=max_sequence_length,
learning_rate_fn=learning_rate_fn,
embedding_lookup_device=embedding_lookup_device,
tensor_core_shape=tensor_core_shape)
tpu_columns.append(column)
return tpu_columns
@ -601,3 +712,211 @@ def split_sequence_columns_v2(feature_columns):
else:
non_sequence_columns.append(column)
return sequence_columns, non_sequence_columns
def sparse_embedding_aggregate_slice(params,
                                     values_and_values_mask,
                                     combiner='mean',
                                     name='sparse_embedding_aggregate_slice'):
  """Performs a dense, padded embedding lookup and combines each row.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  Every index in `values` is looked up in `params`; each looked-up vector is
  multiplied by the matching entry of `values_mask` and the products are
  summed over the index axis. For the 'mean' combiner the sum is divided by
  the row's total mask weight.

  Args:
    params: Tensor of embedding table. Rank 2 (table_size x embedding dim)
    values_and_values_mask: A two-tuple that contains:
      values - Tensor of embedding indices. Rank 2 (batch x n_indices)
      values_mask - Tensor of mask / weights. Rank 2 (batch x n_indices)
    combiner: The combiner to use for the embedding lookup. Currently supports
      'sum' and 'mean'.
    name: Optional name scope for created ops

  Returns:
    Rank 2 tensor of aggregated (per batch element) embedding vectors. With
    the 'mean' combiner, a row whose mask sums to zero yields zeros instead
    of NaN.

  Raises:
    ValueError: Combiner is not supported.
  """
  # Fail before any op is added to the graph.
  if combiner not in ('sum', 'mean'):
    raise ValueError('Dense TPU Embedding does not support combiner '
                     'other than sum and mean.')

  values, values_mask = values_and_values_mask  # unpack the two-tuple
  with ops.name_scope(name):
    _, embedding_dimension = params.get_shape().as_list()
    n_batch, n_indices_padded = values.get_shape().as_list()
    if not n_batch:
      # The static batch size is unknown; -1 lets reshape infer it.
      n_batch = -1

    emb_lookup = array_ops.reshape(
        embedding_ops.embedding_lookup(
            params, array_ops.reshape(values, [n_batch, n_indices_padded])),
        [n_batch, n_indices_padded, embedding_dimension])

    # Trailing unit axis so the mask broadcasts over the embedding dimension.
    values_mask_broadcast = array_ops.reshape(values_mask,
                                              [n_batch, n_indices_padded, 1])
    aggregate_emb = math_ops.reduce_sum(
        emb_lookup * values_mask_broadcast, axis=1)
    if combiner == 'sum':
      return aggregate_emb

    # combiner == 'mean'. A row with no ids has zero total weight; a plain
    # division would give 0 / 0 = NaN, so use the division that returns 0
    # wherever the denominator is 0.
    total_weight = math_ops.reduce_sum(values_mask_broadcast, axis=1)
    return math_ops.div_no_nan(aggregate_emb, total_weight)
def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size):
  """Densifies sparse lookup ids into fixed-width id and mask tensors.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  The input is first sliced to its leading `padded_size` columns, then
  scattered into two dense tensors of shape (batch, padded_size): one holding
  the ids cast to int32, and one float32 mask that is 1 wherever an id was
  present. Positions without an id are 0 in both outputs.

  Args:
    sparse_indices: SparseTensor of embedding lookup indices.
    padded_size: Number of columns of the returned Tensors. Entries whose
      column index falls outside this width are dropped.

  Returns:
    A (padded_values, padded_mask) tuple as described above.
  """
  num_rows = sparse_indices.dense_shape[0]
  output_shape = (num_rows, padded_size)

  # Keep only the entries that fit inside the padded width.
  clipped = sparse_ops.sparse_slice(sparse_indices, [0, 0],
                                    [num_rows, padded_size])
  positions = clipped.indices
  ids = clipped.values

  padded_values = array_ops.scatter_nd(
      positions, math_ops.cast(ids, dtypes.int32), shape=output_shape)

  # One weight of 1.0 per surviving id, scattered to the same positions.
  unit_weights = array_ops.ones_like(ids, dtype=dtypes.float32)
  padded_mask = array_ops.scatter_nd(
      positions, unit_weights, shape=output_shape)

  return padded_values, padded_mask
class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
  """TPUEmbeddingColumn which allows serving on TensorCore.

  Adds two keyword-only constructor arguments on top of the parent column:
  `embedding_lookup_device` (an `EmbeddingDevice`) and `tensor_core_shape`
  (the dense shape used for the padded TensorCore lookup; only element 1,
  the padded width, is read here).
  """

  def __new__(cls, *args, **kwargs):
    # Capture and strip the device-specific keyword arguments, then delegate.
    # NOTE(review): these are stored on the class, so they are shared by every
    # instance until __init__ shadows them per instance - confirm the
    # class-level copy is still required.
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs.pop('tensor_core_shape')
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs.pop('embedding_lookup_device')
    return _TPUEmbeddingColumnV2.__new__(cls, *args, **kwargs)

  def __init__(self, *args, **kwargs):
    # Capture and strip the device-specific keyword arguments, then delegate.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs.pop('tensor_core_shape')
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs.pop('embedding_lookup_device')
    _TPUEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def create_state(self, state_manager):
    """Creates the column state, rejecting unsupported device/mode pairs.

    Args:
      state_manager: Forwarded unchanged to the delegated `create_state`.

    Returns:
      Whatever the delegated `create_state` returns.

    Raises:
      ValueError: If the device is TPU_EMBEDDING_CORE under a TPU inference
        context, or CPU outside of one.
    """
    in_inference = tpu.under_tpu_inference_context()
    if (in_inference and
        self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
      raise ValueError(
          'Using embedding_lookup_device=tpu_embedding_core during inference '
          'is not supported.')
    if self._embedding_lookup_device == EmbeddingDevice.CPU:
      if in_inference:
        return fc_lib.EmbeddingColumn.create_state(self, state_manager)
      raise ValueError(
          'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
          'during training is not supported.')

    return super(_TPUDeviceSpecificEmbeddingColumnV2,
                 self).create_state(state_manager)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Private method that follows get_dense_tensor.

    The padded TensorCore lookup is used only under a TPU inference context
    when the column was configured for TPU_TENSOR_CORE and a
    `tensor_core_shape` was given. Every other combination, including a CPU
    column that happens to carry a `tensor_core_shape`, is delegated to the
    parent class.

    Args:
      transformation_cache: Cache from which the categorical input is read.
      state_manager: Provides the 'embedding_weights' variable.

    Returns:
      The combined embedding tensor.
    """
    use_tensor_core = bool(
        tpu.under_tpu_inference_context() and
        self._embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE and
        self._tensor_core_shape)
    if not use_tensor_core:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).get_dense_tensor(transformation_cache, state_manager)

    sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                             state_manager)

    # Use outside compile to densify and pad the input tensors.
    def host_computation():
      return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                 self._tensor_core_shape[1])

    values, mask = tpu.outside_compilation(host_computation)

    # Do a dense embedding lookup on TensorCore.
    embedding_weights = state_manager.get_variable(self, 'embedding_weights')
    embedding = sparse_embedding_aggregate_slice(embedding_weights,
                                                 (values, mask),
                                                 self.get_combiner())
    return embedding
class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
  """TPUSharedEmbeddingColumnV2 which allows serving on TensorCore.

  Adds two keyword-only constructor arguments on top of the parent column:
  `embedding_lookup_device` (an `EmbeddingDevice`) and `tensor_core_shape`
  (the dense shape used for the padded TensorCore lookup; only element 1,
  the padded width, is read here).
  """

  def __new__(cls, *args, **kwargs):
    # Capture and strip the device-specific keyword arguments, then delegate.
    # NOTE(review): these are stored on the class, so they are shared by every
    # instance until __init__ shadows them per instance - confirm the
    # class-level copy is still required.
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs.pop('tensor_core_shape')
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs.pop('embedding_lookup_device')
    return _TPUSharedEmbeddingColumnV2.__new__(cls, *args, **kwargs)

  def __init__(self, *args, **kwargs):
    # Capture and strip the device-specific keyword arguments, then delegate.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs.pop('tensor_core_shape')
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs.pop('embedding_lookup_device')
    _TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Private method that follows _get_dense_tensor_internal.

    Args:
      transformation_cache: Cache from which the categorical input is read.
      state_manager: Forwarded to the cache and to the parent implementation.

    Returns:
      The combined embedding tensor.

    Raises:
      ValueError: If the device is TPU_EMBEDDING_CORE under a TPU inference
        context, if it is CPU outside of one, or if the TensorCore path is
        reached without a `tensor_core_shape`.
    """
    in_inference = tpu.under_tpu_inference_context()
    device = self._embedding_lookup_device

    if in_inference and device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      raise ValueError('Using embedding_lookup_device=tpu_embedding_core '
                       'during inference is not supported.')
    if device == EmbeddingDevice.CPU:
      if in_inference:
        return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                     self)._get_dense_tensor_internal(transformation_cache,
                                                      state_manager)
      raise ValueError(
          'Using TPUSharedEmbeddingColumn with '
          'embedding_lookup_device="cpu" during training is not supported.')
    if device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      # Training on the embedding core is the parent's job; it must not fall
      # through to the TensorCore lookup below.
      return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor_internal(transformation_cache,
                                                    state_manager)

    # TensorCore path: the padded width comes from tensor_core_shape[1].
    if not self._tensor_core_shape:
      raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                       'tensor_core_shape to be set.')

    sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                             state_manager)

    # Use outside compile to densify and pad the input tensors.
    def host_computation():
      return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                 self._tensor_core_shape[1])

    values, mask = tpu.outside_compilation(host_computation)

    # Do a dense embedding lookup on TensorCore.
    embedding_weights = self.shared_embedding_column_creator.embedding_weights
    embedding = sparse_embedding_aggregate_slice(embedding_weights,
                                                 (values, mask),
                                                 self.get_combiner())
    return embedding

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.client import session
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import dtypes
@ -29,6 +31,7 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.tpu import feature_column_v2 as tpu_fc
from tensorflow.python.tpu import tpu
def _initialized_session():
@ -301,5 +304,204 @@ class SharedEmbeddingColumnTestV2(test.TestCase):
embedding_lookup_b[0].eval())
class DeviceSpecificEmbeddingColumnTestV2(test.TestCase,
                                          parameterized.TestCase):
  """Tests for embedding columns built with an explicit lookup device."""

  @parameterized.named_parameters(
      {
          'testcase_name': 'invalid_shared',
          'shared': True,
      }, {
          'testcase_name': 'invalid_not_shared',
          'shared': False,
      })
  @test_util.deprecated_graph_mode_only
  def test_invalid_cases(self, shared):
    """Unsupported device/mode combinations raise ValueError."""
    # Inputs.
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1), (1, 4)),
        values=(2, 0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=3)

    # Training on TPU with cpu embedding lookups is not supported.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=2,
          embedding_lookup_device='cpu',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=2,
          embedding_lookup_device='cpu',
          tensor_core_shape=[None, 3])
    dense_features = fc_lib.DenseFeatures(embedding_column)
    with self.assertRaisesRegexp(
        ValueError,
        r'.*embedding_lookup_device=\"cpu\" during training is not'):
      dense_features(input_features)

    # Inference with TPU Embedding Hardware is not supported.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=2,
          embedding_lookup_device='tpu_embedding_core',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=2,
          embedding_lookup_device='tpu_embedding_core',
          tensor_core_shape=[None, 3])
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # Always leave the inference context, even if the assertion fails.
    try:
      dense_features = fc_lib.DenseFeatures(embedding_column)
      with self.assertRaisesRegexp(
          ValueError,
          r'Using embedding_lookup_device=tpu_embedding_core during inference '
          r'is '
      ):
        dense_features(input_features)
    finally:
      context.Exit()

  @parameterized.named_parameters(
      {
          'testcase_name': 'combiner_mean_shared',
          'shared': True,
          'combiner': 'mean'
      }, {
          'testcase_name': 'combiner_sum_shared',
          'shared': True,
          'combiner': 'sum'
      }, {
          'testcase_name': 'combiner_sqrtn_shared',
          'shared': True,
          'combiner': 'sqrtn'
      }, {
          'testcase_name': 'combiner_mean_not_shared',
          'shared': False,
          'combiner': 'mean'
      }, {
          'testcase_name': 'combiner_sum_not_shared',
          'shared': False,
          'combiner': 'sum'
      }, {
          'testcase_name': 'combiner_sqrtn_not_shared',
          'shared': False,
          'combiner': 'sqrtn'
      })
  @test_util.deprecated_graph_mode_only
  def test_dense_embedding_lookup(self, shared, combiner):
    """TensorCore lookup under inference matches hand-computed embeddings."""
    # Inputs.
    vocabulary_size = 3
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1, 3]
        indices=((0, 0), (1, 0), (1, 1), (1, 4)),
        values=(2, 0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.),  # id 2
        (13., 17.)  # id 3
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=vocabulary_size)

    # Set tensor_core_shape to be [None, 3]: a dynamic batch size, and a
    # padded width that drops the id stored at column 4 of example 1.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=embedding_dimension,
          initializer=_initializer,
          combiner=combiner,
          embedding_lookup_device='tpu_tensor_core',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=embedding_dimension,
          initializer=_initializer,
          combiner=combiner,
          embedding_lookup_device='tpu_tensor_core',
          tensor_core_shape=[None, 3])

    # Run in TPUInferenceContext so that we hit the intended densification case.
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # The finally clause also covers the early return in the sqrtn case, which
    # would otherwise leave the inference context entered.
    try:
      dense_features = fc_lib.DenseFeatures(embedding_column)
      # Sqrtn combiner not supported for now.
      if combiner == 'sqrtn':
        with self.assertRaisesRegexp(
            ValueError, 'Dense TPU Embedding does not support combiner'):
          embedding_lookup = dense_features(input_features)
        return
      if combiner == 'mean':
        expected_lookups = (
            # example 0:
            (7., 11.),  # ids [2], embedding = [7, 11]
            # example 1:
            (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) =
                        # [2, 3.5]
        )
      elif combiner == 'sum':
        expected_lookups = (
            # example 0:
            (7., 11.),  # ids [2], embedding = [7, 11]
            # example 1:
            (4., 7),  # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7]
        )

      embedding_lookup = dense_features(input_features)

      # Assert expected embedding variable and lookups.
      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      if shared:
        self.assertCountEqual(('inp_shared_embedding:0',),
                              tuple([v.name for v in global_vars]))
      else:
        self.assertCountEqual(
            ('dense_features/inp_embedding/embedding_weights:0',),
            tuple([v.name for v in global_vars]))

      embedding_var = global_vars[0]
      with _initialized_session():
        self.assertAllEqual(embedding_values, embedding_var.eval())
        eval_res = embedding_lookup.eval()
        self.assertAllEqual(expected_lookups, eval_res)
    finally:
      context.Exit()

  @test_util.deprecated_graph_mode_only
  def test_error_dense_shape_invalid(self):
    """A tensor_core_shape that is not of length 2 is rejected."""
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=5)
    with self.assertRaisesRegexp(ValueError,
                                 'tensor_core_shape must be size 2'):
      tpu_fc.shared_embedding_columns_v2([categorical_column_input],
                                         dimension=20,
                                         tensor_core_shape=[None, 20, 15])
if __name__ == '__main__':
test.main()

View File

@ -18,7 +18,7 @@ tf_module {
}
member_method {
name: "embedding_column"
argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'max_sequence_length\', \'learning_rate_fn\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'0\', \'None\'], "
argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'max_sequence_length\', \'learning_rate_fn\', \'embedding_lookup_device\', \'tensor_core_shape\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'0\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "initialize_tpu_system"
@ -26,6 +26,6 @@ tf_module {
}
member_method {
name: "shared_embedding_columns"
argspec: "args=[\'categorical_columns\', \'dimension\', \'combiner\', \'initializer\', \'shared_embedding_collection_name\', \'max_sequence_lengths\', \'learning_rate_fn\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'categorical_columns\', \'dimension\', \'combiner\', \'initializer\', \'shared_embedding_collection_name\', \'max_sequence_lengths\', \'learning_rate_fn\', \'embedding_lookup_device\', \'tensor_core_shape\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
}