Expose TPU embedding partition strategy as on option that can be set (either mod or div). Add support for this into TPU feature column API.

PiperOrigin-RevId: 245485369
This commit is contained in:
Bruce Fontaine 2019-04-26 14:14:09 -07:00 committed by TensorFlower Gardener
parent ce22533806
commit dfefb28735
2 changed files with 75 additions and 13 deletions

View File

@ -50,7 +50,8 @@ def embedding_column(categorical_column,
dimension,
combiner='mean',
initializer=None,
max_sequence_length=0):
max_sequence_length=0,
partition_strategy='div'):
"""TPU embedding_column for `tf.feature_column.embedding_column`.
Note that the interface for TPU embedding_column is different from the non-TPU
@ -77,6 +78,11 @@ def embedding_column(categorical_column,
length. Any sequence shorter then this will be padded with 0 embeddings
and any sequence longer will be truncated. This must be positive for
sequence features and 0 for non-sequence features.
partition_strategy: Determines how tensors are sharded on the tpu hosts. See
`tf.nn.safe_embedding_lookup_sparse` for more details. Allowed value are
`"div"` and `"mod"'. If `"mod"` is used, evaluation and exporting the
model to CPU will not work. In order to do this, you must shuffle the
embedding tensors into a single shard.
Returns:
A _TPUEmbeddingColumn.
@ -122,7 +128,8 @@ def embedding_column(categorical_column,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=max_sequence_length)
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
# For Embedding column, the initializer is hidden inside the creator Fn, which
# is not accessiable later. So, we attach it to a speicial field. Also note
# that non-TPU Embedding column and non-TPU shared Embedding column handle the
@ -136,7 +143,8 @@ def shared_embedding_columns(categorical_columns,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None,
max_sequence_lengths=None):
max_sequence_lengths=None,
partition_strategy='div'):
"""List of dense columns that convert from sparse, categorical input.
Note that the interface for TPU embedding_column is different from the non-TPU
@ -169,6 +177,9 @@ def shared_embedding_columns(categorical_columns,
to sequence columns specify the max sequence length for the column. Any
sequence shorter then this will be padded with 0 embeddings and any
sequence longer will be truncated.
partition_strategy: Determines how tensors are sharded on the tpu hosts. See
`tf.nn.safe_embedding_lookup_sparse` for more details. Allowed value are
`"div"` and `"mod"'.
Returns:
A _TPUEmbeddingColumn.
@ -238,7 +249,8 @@ def shared_embedding_columns(categorical_columns,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=max_sequence_length)
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
tpu_columns.append(column)
return tpu_columns
@ -247,7 +259,8 @@ def shared_embedding_columns(categorical_columns,
class _TPUBaseEmbeddingColumn(object):
"""Base class for TPU Embedding Column."""
def __init__(self, categorical_column, max_sequence_length=0):
def __init__(self, categorical_column, max_sequence_length=0,
partition_strategy='div'):
self._tpu_categorical_column = categorical_column
self._max_sequence_length = max_sequence_length
if (self.is_sequence_column() and max_sequence_length < 1):
@ -259,6 +272,10 @@ class _TPUBaseEmbeddingColumn(object):
raise ValueError('Non zero max_seq_length={} specified for non '
'sequence column {}.'.format(max_sequence_length,
categorical_column.name))
self._partition_strategy = partition_strategy
if partition_strategy not in ('mod', 'div'):
raise ValueError('partition_strategy must be one of `mod` or `div`. '
'Received {}.'.format(partition_strategy))
def get_combiner(self):
"""Returns the embedding combiner."""
@ -303,6 +320,9 @@ class _TPUBaseEmbeddingColumn(object):
return get_sequence_length_feature_key_name_from_feature_key_name(
self.get_feature_key_name())
def get_partition_strategy(self):
return self._partition_strategy
class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
"""Core Embedding Column."""
@ -316,7 +336,8 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0):
max_sequence_length=0,
partition_strategy='div'):
# Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
# are not supported on TPU. They are solely for matching the signature of
# __new__ of parent class fc._EmbeddingColumn.
@ -340,9 +361,11 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0):
max_sequence_length=0,
partition_strategy='div'):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column,
max_sequence_length=max_sequence_length)
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
self._key = None
def get_combiner(self):
@ -383,12 +406,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
@ -405,12 +434,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._EmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._EmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)
@ -443,7 +478,8 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0):
max_sequence_length=0,
partition_strategy='div'):
return fc._SharedEmbeddingColumn.__new__(
cls,
categorical_column,
@ -466,10 +502,12 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0):
max_sequence_length=0,
partition_strategy='div'):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column,
max_sequence_length=max_sequence_length)
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
self._key = None
def get_combiner(self):
@ -510,12 +548,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
@ -533,12 +577,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)

View File

@ -412,7 +412,8 @@ class TPUEmbedding(object):
master,
optimization_parameters=None,
cluster_def=None,
pipeline_execution_with_tensor_core=False):
pipeline_execution_with_tensor_core=False,
partition_strategy='div'):
"""API for using TPU for embedding lookups.
Args:
@ -433,10 +434,18 @@ class TPUEmbedding(object):
faster, but trained model will be different if step N and step N+1
involve the same set of embedding IDs. Please see
`tpu_embedding_configuration.proto` for details.
partition_strategy: A string, either 'mod' or 'div', specifying how to map
the lookup id to the embedding tensor. For more information see
`tf.nn.embedding_lookup_sparse`.
Raises:
ValueError: if any input is invalid.
"""
if partition_strategy not in ('div', 'mod'):
raise ValueError(
'Invalid partition_strategy {}'.format(partition_strategy))
self._partition_strategy = partition_strategy
_validate_table_to_config_dict(table_to_config_dict)
# Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
@ -598,7 +607,10 @@ class TPUEmbedding(object):
config_proto.batch_size_per_tensor_core = self._batch_size_per_core
config_proto.num_hosts = self._num_hosts
config_proto.num_tensor_cores = self._num_cores
config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
config_proto.sharding_strategy = (
elc.TPUEmbeddingConfiguration.DIV_DEFAULT
if self._partition_strategy == 'div' else
elc.TPUEmbeddingConfiguration.MOD)
config_proto.pipeline_execution_with_tensor_core = (
self._pipeline_execution_with_tensor_core)