Expose the TPU embedding partition strategy as an option that can be set (either mod or div). Add support for this to the TPU feature column API.
PiperOrigin-RevId: 245485369
This commit is contained in: parent ce22533806, commit dfefb28735
@@ -50,7 +50,8 @@ def embedding_column(categorical_column,
                      dimension,
                      combiner='mean',
                      initializer=None,
-                     max_sequence_length=0):
+                     max_sequence_length=0,
+                     partition_strategy='div'):
   """TPU embedding_column for `tf.feature_column.embedding_column`.
 
   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -77,6 +78,11 @@ def embedding_column(categorical_column,
       length. Any sequence shorter then this will be padded with 0 embeddings
       and any sequence longer will be truncated. This must be positive for
       sequence features and 0 for non-sequence features.
+    partition_strategy: Determines how tensors are sharded on the TPU hosts. See
+      `tf.nn.safe_embedding_lookup_sparse` for more details. Allowed values are
+      `"div"` and `"mod"`. If `"mod"` is used, evaluating and exporting the
+      model to CPU will not work. To do either, you must first shuffle the
+      embedding tensors into a single shard.
 
   Returns:
     A _TPUEmbeddingColumn.
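For illustration, a minimal usage sketch of the new argument on `embedding_column` (the `tpu_fc` alias, the import path, and the categorical column are assumptions made for this sketch, not part of the change):

import tensorflow as tf
# Import path is an assumption; use whatever module exposes the TPU feature
# column helpers shown in this diff.
from tensorflow.python.tpu import feature_column as tpu_fc

video_id = tf.feature_column.categorical_column_with_identity(
    key='video_id', num_buckets=1000000)

# 'div' (the default) keeps CPU evaluation/export possible; 'mod' spreads rows
# across hosts by id % num_hosts but blocks CPU eval/export (see the hunks below).
video_emb = tpu_fc.embedding_column(
    video_id,
    dimension=64,
    combiner='mean',
    partition_strategy='div')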
@@ -122,7 +128,8 @@ def embedding_column(categorical_column,
       tensor_name_in_ckpt=None,
       max_norm=None,
       trainable=True,
-      max_sequence_length=max_sequence_length)
+      max_sequence_length=max_sequence_length,
+      partition_strategy=partition_strategy)
   # For Embedding column, the initializer is hidden inside the creator Fn, which
   # is not accessiable later. So, we attach it to a speicial field. Also note
   # that non-TPU Embedding column and non-TPU shared Embedding column handle the
@@ -136,7 +143,8 @@ def shared_embedding_columns(categorical_columns,
                              combiner='mean',
                              initializer=None,
                              shared_embedding_collection_name=None,
-                             max_sequence_lengths=None):
+                             max_sequence_lengths=None,
+                             partition_strategy='div'):
   """List of dense columns that convert from sparse, categorical input.
 
   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -169,6 +177,9 @@ def shared_embedding_columns(categorical_columns,
       to sequence columns specify the max sequence length for the column. Any
       sequence shorter then this will be padded with 0 embeddings and any
       sequence longer will be truncated.
+    partition_strategy: Determines how tensors are sharded on the TPU hosts. See
+      `tf.nn.safe_embedding_lookup_sparse` for more details. Allowed values are
+      `"div"` and `"mod"`.
 
   Returns:
     A _TPUEmbeddingColumn.
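A matching sketch for `shared_embedding_columns`, reusing the assumed `tpu_fc` alias from the sketch above:

watched = tf.feature_column.categorical_column_with_identity(
    key='watched_video_id', num_buckets=1000000)
impression = tf.feature_column.categorical_column_with_identity(
    key='impression_video_id', num_buckets=1000000)

# Both columns share one embedding table; 'mod' balances rows across hosts but,
# per the docstring above, gives up CPU evaluation and export.
shared_embs = tpu_fc.shared_embedding_columns(
    [watched, impression],
    dimension=64,
    combiner='mean',
    partition_strategy='mod')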
@@ -238,7 +249,8 @@ def shared_embedding_columns(categorical_columns,
         tensor_name_in_ckpt=None,
         max_norm=None,
         trainable=True,
-        max_sequence_length=max_sequence_length)
+        max_sequence_length=max_sequence_length,
+        partition_strategy=partition_strategy)
     tpu_columns.append(column)
 
   return tpu_columns
@@ -247,7 +259,8 @@ def shared_embedding_columns(categorical_columns,
 class _TPUBaseEmbeddingColumn(object):
   """Base class for TPU Embedding Column."""
 
-  def __init__(self, categorical_column, max_sequence_length=0):
+  def __init__(self, categorical_column, max_sequence_length=0,
+               partition_strategy='div'):
     self._tpu_categorical_column = categorical_column
     self._max_sequence_length = max_sequence_length
     if (self.is_sequence_column() and max_sequence_length < 1):
@@ -259,6 +272,10 @@ class _TPUBaseEmbeddingColumn(object):
       raise ValueError('Non zero max_seq_length={} specified for non '
                        'sequence column {}.'.format(max_sequence_length,
                                                     categorical_column.name))
+    self._partition_strategy = partition_strategy
+    if partition_strategy not in ('mod', 'div'):
+      raise ValueError('partition_strategy must be one of `mod` or `div`. '
+                       'Received {}.'.format(partition_strategy))
 
   def get_combiner(self):
     """Returns the embedding combiner."""
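Given the check added here, an unsupported value fails as soon as the column is constructed. A small caller-side sketch (column setup as in the earlier sketch):

try:
  tpu_fc.embedding_column(video_id, dimension=64, partition_strategy='even')
except ValueError as e:
  # Roughly: "partition_strategy must be one of `mod` or `div`. Received even."
  print(e)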
@@ -303,6 +320,9 @@ class _TPUBaseEmbeddingColumn(object):
     return get_sequence_length_feature_key_name_from_feature_key_name(
         self.get_feature_key_name())
 
+  def get_partition_strategy(self):
+    return self._partition_strategy
+
 
 class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
   """Core Embedding Column."""
@@ -316,7 +336,8 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
     # are not supported on TPU. They are solely for matching the signature of
     # __new__ of parent class fc._EmbeddingColumn.
@@ -340,9 +361,11 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
     _TPUBaseEmbeddingColumn.__init__(self, categorical_column,
-                                     max_sequence_length=max_sequence_length)
+                                     max_sequence_length=max_sequence_length,
+                                     partition_strategy=partition_strategy)
     self._key = None
 
   def get_combiner(self):
@@ -383,12 +406,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
 
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)
 
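The NotImplementedError added here (and in the analogous hunks below) follows from how 'mod' lays rows out: the full table can only be rebuilt by interleaving shard rows, not by concatenating the shards, which is what the docstring above means by shuffling the embedding tensors into a single shard. A small NumPy sketch, illustrative only:

import numpy as np

vocab, dim, num_shards = 10, 4, 2
table = np.arange(vocab * dim, dtype=np.float32).reshape(vocab, dim)

# 'mod' sharding: row i lives on shard i % num_shards at local row i // num_shards.
mod_shards = [table[s::num_shards] for s in range(num_shards)]
# Plainly concatenating the shards does NOT restore the original row order...
assert not np.array_equal(np.concatenate(mod_shards), table)
# ...the rows have to be interleaved back by id.
restored = np.empty_like(table)
for s, shard in enumerate(mod_shards):
  restored[s::num_shards] = shard
assert np.array_equal(restored, table)

# 'div' sharding keeps contiguous blocks, so concatenation already matches.
div_shards = np.array_split(table, num_shards)
assert np.array_equal(np.concatenate(div_shards), table)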
@@ -405,12 +434,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -443,7 +478,8 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     return fc._SharedEmbeddingColumn.__new__(
         cls,
         categorical_column,
@@ -466,10 +502,12 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
 
     _TPUBaseEmbeddingColumn.__init__(self, categorical_column,
-                                     max_sequence_length=max_sequence_length)
+                                     max_sequence_length=max_sequence_length,
+                                     partition_strategy=partition_strategy)
     self._key = None
 
   def get_combiner(self):
@@ -510,12 +548,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
 
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._SharedEmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -533,12 +577,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -412,7 +412,8 @@ class TPUEmbedding(object):
                master,
                optimization_parameters=None,
                cluster_def=None,
-               pipeline_execution_with_tensor_core=False):
+               pipeline_execution_with_tensor_core=False,
+               partition_strategy='div'):
     """API for using TPU for embedding lookups.
 
     Args:
@@ -433,10 +434,18 @@ class TPUEmbedding(object):
         faster, but trained model will be different if step N and step N+1
         involve the same set of embedding IDs. Please see
         `tpu_embedding_configuration.proto` for details.
+      partition_strategy: A string, either 'mod' or 'div', specifying how to map
+        the lookup id to the embedding tensor. For more information see
+        `tf.nn.embedding_lookup_sparse`.
 
     Raises:
       ValueError: if any input is invalid.
     """
+    if partition_strategy not in ('div', 'mod'):
+      raise ValueError(
+          'Invalid partition_strategy {}'.format(partition_strategy))
+    self._partition_strategy = partition_strategy
+
     _validate_table_to_config_dict(table_to_config_dict)
     # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
     self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
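The 'div'/'mod' semantics referenced above follow `tf.nn.embedding_lookup_sparse`. A rough sketch of the id-to-shard mapping (it ignores the uneven-remainder handling of 'div', where the first shards receive one extra row):

def shard_for_id(embedding_id, vocab_size, num_shards, partition_strategy='div'):
  """Approximate shard index of an embedding row under each strategy."""
  if partition_strategy == 'mod':
    return embedding_id % num_shards  # ids interleave across shards
  rows_per_shard = -(-vocab_size // num_shards)  # ceiling division
  return embedding_id // rows_per_shard  # contiguous blocks of ids

# e.g. 10 rows on 2 shards: id 3 -> shard 1 under 'mod', shard 0 under 'div'.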
@@ -598,7 +607,10 @@ class TPUEmbedding(object):
     config_proto.batch_size_per_tensor_core = self._batch_size_per_core
     config_proto.num_hosts = self._num_hosts
     config_proto.num_tensor_cores = self._num_cores
-    config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+    config_proto.sharding_strategy = (
+        elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+        if self._partition_strategy == 'div' else
+        elc.TPUEmbeddingConfiguration.MOD)
     config_proto.pipeline_execution_with_tensor_core = (
         self._pipeline_execution_with_tensor_core)
 