Expose the TPU embedding partition strategy as an option that can be set to either mod or div. Add support for this to the TPU feature column API.

PiperOrigin-RevId: 245485369
Bruce Fontaine, 2019-04-26 14:14:09 -07:00 (committed by TensorFlower Gardener)
parent ce22533806
commit dfefb28735
2 changed files with 75 additions and 13 deletions
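
At the API surface, the change adds a `partition_strategy` keyword argument (defaulting to 'div') to the TPU feature column constructors and to the TPUEmbedding class, as shown in the two diffs below. A minimal usage sketch follows; the import path for the TPU feature column module and the categorical column setup are assumptions for illustration, and only the keyword arguments themselves come from this commit:

# Hedged usage sketch of the new argument; the module path is assumed.
import tensorflow as tf
# from tensorflow.python.tpu import feature_column as tpu_fc  # path assumed

watched_video = tf.feature_column.categorical_column_with_identity(
    key='watched_video_id', num_buckets=1000000)

# 'div' (the default) keeps CPU evaluation and SavedModel export working.
# 'mod' interleaves vocabulary rows across TPU hosts; per the docstring
# added below, it cannot be evaluated or exported on CPU without first
# shuffling the embedding tensors into a single shard.
video_embedding = tpu_fc.embedding_column(
    watched_video,
    dimension=64,
    combiner='mean',
    partition_strategy='div')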


@@ -50,7 +50,8 @@ def embedding_column(categorical_column,
                      dimension,
                      combiner='mean',
                      initializer=None,
-                     max_sequence_length=0):
+                     max_sequence_length=0,
+                     partition_strategy='div'):
   """TPU embedding_column for `tf.feature_column.embedding_column`.
 
   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -77,6 +78,11 @@ def embedding_column(categorical_column,
       length. Any sequence shorter then this will be padded with 0 embeddings
       and any sequence longer will be truncated. This must be positive for
       sequence features and 0 for non-sequence features.
+    partition_strategy: Determines how tensors are sharded on the TPU hosts.
+      See `tf.nn.safe_embedding_lookup_sparse` for more details. Allowed
+      values are `"div"` and `"mod"`. If `"mod"` is used, evaluation and
+      exporting the model to CPU will not work. To do so, you must first
+      shuffle the embedding tensors into a single shard.
 
   Returns:
     A _TPUEmbeddingColumn.
@@ -122,7 +128,8 @@ def embedding_column(categorical_column,
       tensor_name_in_ckpt=None,
       max_norm=None,
       trainable=True,
-      max_sequence_length=max_sequence_length)
+      max_sequence_length=max_sequence_length,
+      partition_strategy=partition_strategy)
   # For Embedding column, the initializer is hidden inside the creator Fn, which
   # is not accessiable later. So, we attach it to a speicial field. Also note
   # that non-TPU Embedding column and non-TPU shared Embedding column handle the
@@ -136,7 +143,8 @@ def shared_embedding_columns(categorical_columns,
                              combiner='mean',
                              initializer=None,
                              shared_embedding_collection_name=None,
-                             max_sequence_lengths=None):
+                             max_sequence_lengths=None,
+                             partition_strategy='div'):
   """List of dense columns that convert from sparse, categorical input.
 
   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -169,6 +177,9 @@ def shared_embedding_columns(categorical_columns,
       to sequence columns specify the max sequence length for the column. Any
       sequence shorter then this will be padded with 0 embeddings and any
       sequence longer will be truncated.
+    partition_strategy: Determines how tensors are sharded on the TPU hosts.
+      See `tf.nn.safe_embedding_lookup_sparse` for more details. Allowed
+      values are `"div"` and `"mod"`.
 
   Returns:
     A _TPUEmbeddingColumn.
@@ -238,7 +249,8 @@ def shared_embedding_columns(categorical_columns,
         tensor_name_in_ckpt=None,
         max_norm=None,
         trainable=True,
-        max_sequence_length=max_sequence_length)
+        max_sequence_length=max_sequence_length,
+        partition_strategy=partition_strategy)
     tpu_columns.append(column)
 
   return tpu_columns
@@ -247,7 +259,8 @@ def shared_embedding_columns(categorical_columns,
 class _TPUBaseEmbeddingColumn(object):
   """Base class for TPU Embedding Column."""
 
-  def __init__(self, categorical_column, max_sequence_length=0):
+  def __init__(self, categorical_column, max_sequence_length=0,
+               partition_strategy='div'):
     self._tpu_categorical_column = categorical_column
     self._max_sequence_length = max_sequence_length
     if (self.is_sequence_column() and max_sequence_length < 1):
@@ -259,6 +272,10 @@ class _TPUBaseEmbeddingColumn(object):
       raise ValueError('Non zero max_seq_length={} specified for non '
                        'sequence column {}.'.format(max_sequence_length,
                                                     categorical_column.name))
+    self._partition_strategy = partition_strategy
+    if partition_strategy not in ('mod', 'div'):
+      raise ValueError('partition_strategy must be one of `mod` or `div`. '
+                       'Received {}.'.format(partition_strategy))
 
   def get_combiner(self):
     """Returns the embedding combiner."""
@@ -303,6 +320,9 @@ class _TPUBaseEmbeddingColumn(object):
     return get_sequence_length_feature_key_name_from_feature_key_name(
         self.get_feature_key_name())
 
+  def get_partition_strategy(self):
+    return self._partition_strategy
+
 
 class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
   """Core Embedding Column."""
@@ -316,7 +336,8 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
     # are not supported on TPU. They are solely for matching the signature of
     # __new__ of parent class fc._EmbeddingColumn.
@@ -340,9 +361,11 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
     _TPUBaseEmbeddingColumn.__init__(self, categorical_column,
-                                     max_sequence_length=max_sequence_length)
+                                     max_sequence_length=max_sequence_length,
+                                     partition_strategy=partition_strategy)
     self._key = None
 
   def get_combiner(self):
@@ -383,12 +406,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
 
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -405,12 +434,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
 
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)
@@ -443,7 +478,8 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     return fc._SharedEmbeddingColumn.__new__(
         cls,
         categorical_column,
@@ -466,10 +502,12 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
 
     _TPUBaseEmbeddingColumn.__init__(self, categorical_column,
-                                     max_sequence_length=max_sequence_length)
+                                     max_sequence_length=max_sequence_length,
+                                     partition_strategy=partition_strategy)
     self._key = None
 
   def get_combiner(self):
@@ -510,12 +548,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
 
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._SharedEmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -533,12 +577,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
 
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
      def host_computation():
         return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)
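
The 'div'/'mod' semantics these checks rely on follow the sharding rules documented for `tf.nn.embedding_lookup`: 'div' gives each shard a contiguous block of vocabulary rows (the first `num_ids % num_shards` shards receive one extra row when the split is uneven), while 'mod' places row `i` on shard `i % num_shards`. The standalone sketch below is illustrative and not part of the commit; it only shows the resulting id-to-shard mapping:

def shard_for_id(embedding_id, num_ids, num_shards, strategy):
  # Illustration only; mirrors the documented 'div'/'mod' rules rather than
  # any function in this commit.
  if strategy == 'mod':
    # Consecutive ids are interleaved across shards.
    return embedding_id % num_shards
  if strategy == 'div':
    # Rows are split into contiguous blocks; the first (num_ids % num_shards)
    # shards each hold one extra row.
    num_big = num_ids % num_shards
    big_size = num_ids // num_shards + 1
    small_size = num_ids // num_shards
    if embedding_id < num_big * big_size:
      return embedding_id // big_size
    return num_big + (embedding_id - num_big * big_size) // small_size
  raise ValueError('strategy must be "div" or "mod", got {}'.format(strategy))

# With 10 rows on 3 shards:
#   div -> [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]   (contiguous blocks)
#   mod -> [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]   (interleaved)
print([shard_for_id(i, 10, 3, 'div') for i in range(10)])
print([shard_for_id(i, 10, 3, 'mod') for i in range(10)])

A div-sharded table is just the original table cut into contiguous slices, so the existing CPU lookup and SavedModel export paths keep working. Under mod sharding the rows are interleaved across hosts, which is why the checks above raise NotImplementedError until the embedding shards are shuffled back into a single tensor.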


@@ -412,7 +412,8 @@ class TPUEmbedding(object):
                master,
                optimization_parameters=None,
                cluster_def=None,
-               pipeline_execution_with_tensor_core=False):
+               pipeline_execution_with_tensor_core=False,
+               partition_strategy='div'):
     """API for using TPU for embedding lookups.
 
     Args:
@@ -433,10 +434,18 @@ class TPUEmbedding(object):
         faster, but trained model will be different if step N and step N+1
         involve the same set of embedding IDs. Please see
         `tpu_embedding_configuration.proto` for details.
+      partition_strategy: A string, either 'mod' or 'div', specifying how to map
+        the lookup id to the embedding tensor. For more information see
+        `tf.nn.embedding_lookup_sparse`.
 
     Raises:
       ValueError: if any input is invalid.
     """
+    if partition_strategy not in ('div', 'mod'):
+      raise ValueError(
+          'Invalid partition_strategy {}'.format(partition_strategy))
+    self._partition_strategy = partition_strategy
+
     _validate_table_to_config_dict(table_to_config_dict)
     # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
     self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
@@ -598,7 +607,10 @@ class TPUEmbedding(object):
     config_proto.batch_size_per_tensor_core = self._batch_size_per_core
     config_proto.num_hosts = self._num_hosts
     config_proto.num_tensor_cores = self._num_cores
-    config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+    config_proto.sharding_strategy = (
+        elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+        if self._partition_strategy == 'div' else
+        elc.TPUEmbeddingConfiguration.MOD)
     config_proto.pipeline_execution_with_tensor_core = (
         self._pipeline_execution_with_tensor_core)
 
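
Taken together, the constructor change validates the string and the configuration-building change maps it onto the TPUEmbeddingConfiguration sharding enum. The sketch below restates that mapping with stand-in constants in place of the real `elc.TPUEmbeddingConfiguration` proto values (an assumption made only so the snippet is self-contained):

# Stand-ins for elc.TPUEmbeddingConfiguration.DIV_DEFAULT / .MOD.
DIV_DEFAULT = 0
MOD = 1

def sharding_strategy_enum(partition_strategy):
  # Mirrors the new __init__ check: only 'div' and 'mod' are accepted.
  if partition_strategy not in ('div', 'mod'):
    raise ValueError('Invalid partition_strategy {}'.format(partition_strategy))
  # Mirrors the config_proto.sharding_strategy selection above.
  return DIV_DEFAULT if partition_strategy == 'div' else MOD

print(sharding_strategy_enum('div'))  # DIV_DEFAULT
print(sharding_strategy_enum('mod'))  # MOD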