Expose the TPU embedding partition strategy as an option that can be set (either 'mod' or 'div'). Add support for this to the TPU feature column API.
PiperOrigin-RevId: 245485369
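As a rough usage sketch (the import path and the example feature below are assumptions for illustration; only the `partition_strategy` argument itself comes from this change):

# Hypothetical example: choosing 'div' (the default) vs 'mod' sharding when
# building a TPU embedding column. The alias below assumes the TPU feature
# column module touched by this diff is importable as `tpu_fc`.
import tensorflow as tf
from tensorflow.python.tpu import feature_column as tpu_fc  # assumed path

video_id = tf.feature_column.categorical_column_with_identity(
    key='video_id', num_buckets=1000 * 1000)

# 'div' keeps CPU evaluation/export working; 'mod' spreads ids across hosts
# round-robin but (per the docstring in the diff) cannot be exported to CPU.
video_embedding = tpu_fc.embedding_column(
    video_id, dimension=64, partition_strategy='div')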
@@ -50,7 +50,8 @@ def embedding_column(categorical_column,
                      dimension,
                      combiner='mean',
                      initializer=None,
-                     max_sequence_length=0):
+                     max_sequence_length=0,
+                     partition_strategy='div'):
   """TPU embedding_column for `tf.feature_column.embedding_column`.
 
   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -77,6 +78,11 @@ def embedding_column(categorical_column,
       length. Any sequence shorter then this will be padded with 0 embeddings
       and any sequence longer will be truncated. This must be positive for
       sequence features and 0 for non-sequence features.
+    partition_strategy: Determines how tensors are sharded on the tpu hosts. See
+      `tf.nn.safe_embedding_lookup_sparse` for more details. Allowed value are
+      `"div"` and `"mod"'. If `"mod"` is used, evaluation and exporting the
+      model to CPU will not work. In order to do this, you must shuffle the
+      embedding tensors into a single shard.
 
   Returns:
     A _TPUEmbeddingColumn.
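The practical difference between the two strategies is how a vocabulary id is mapped to a shard. A toy, framework-free illustration (plain Python, mirroring the rule documented for `tf.nn.embedding_lookup`; not part of the change):

def mod_shard(i, num_shards=4):
  # 'mod': id i lives on shard i % num_shards.
  return i % num_shards

def div_shard(i, num_ids=10, num_shards=4):
  # 'div': ids are split into contiguous blocks; the first
  # (num_ids % num_shards) shards hold one extra id.
  base, extra = divmod(num_ids, num_shards)
  cutoff = extra * (base + 1)
  return i // (base + 1) if i < cutoff else extra + (i - cutoff) // base

# ids 0..9 over 4 shards:
print([mod_shard(i) for i in range(10)])  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
print([div_shard(i) for i in range(10)])  # [0, 0, 0, 1, 1, 1, 2, 2, 3, 3]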
@@ -122,7 +128,8 @@ def embedding_column(categorical_column,
       tensor_name_in_ckpt=None,
       max_norm=None,
       trainable=True,
-      max_sequence_length=max_sequence_length)
+      max_sequence_length=max_sequence_length,
+      partition_strategy=partition_strategy)
   # For Embedding column, the initializer is hidden inside the creator Fn, which
   # is not accessiable later. So, we attach it to a speicial field. Also note
   # that non-TPU Embedding column and non-TPU shared Embedding column handle the
@@ -136,7 +143,8 @@ def shared_embedding_columns(categorical_columns,
                              combiner='mean',
                              initializer=None,
                              shared_embedding_collection_name=None,
-                             max_sequence_lengths=None):
+                             max_sequence_lengths=None,
+                             partition_strategy='div'):
   """List of dense columns that convert from sparse, categorical input.
 
   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -169,6 +177,9 @@ def shared_embedding_columns(categorical_columns,
       to sequence columns specify the max sequence length for the column. Any
       sequence shorter then this will be padded with 0 embeddings and any
       sequence longer will be truncated.
+    partition_strategy: Determines how tensors are sharded on the tpu hosts. See
+      `tf.nn.safe_embedding_lookup_sparse` for more details. Allowed value are
+      `"div"` and `"mod"'.
 
   Returns:
     A _TPUEmbeddingColumn.
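A similarly hedged sketch for the shared variant (the import path and the two example columns are illustrative assumptions, not part of this commit):

import tensorflow as tf
from tensorflow.python.tpu import feature_column as tpu_fc  # assumed path

watched = tf.feature_column.categorical_column_with_identity(
    key='watched_video_id', num_buckets=1000 * 1000)
impression = tf.feature_column.categorical_column_with_identity(
    key='impression_video_id', num_buckets=1000 * 1000)

# Two id columns sharing one embedding table, sharded with 'mod'. As the
# docstring above notes, a 'mod'-sharded table cannot be evaluated or exported
# on CPU without first reshuffling the embedding into a single shard.
shared_columns = tpu_fc.shared_embedding_columns(
    [watched, impression], dimension=64, partition_strategy='mod')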
@@ -238,7 +249,8 @@ def shared_embedding_columns(categorical_columns,
         tensor_name_in_ckpt=None,
         max_norm=None,
         trainable=True,
-        max_sequence_length=max_sequence_length)
+        max_sequence_length=max_sequence_length,
+        partition_strategy=partition_strategy)
     tpu_columns.append(column)
 
   return tpu_columns
@@ -247,7 +259,8 @@ def shared_embedding_columns(categorical_columns,
 class _TPUBaseEmbeddingColumn(object):
   """Base class for TPU Embedding Column."""
 
-  def __init__(self, categorical_column, max_sequence_length=0):
+  def __init__(self, categorical_column, max_sequence_length=0,
+               partition_strategy='div'):
     self._tpu_categorical_column = categorical_column
     self._max_sequence_length = max_sequence_length
     if (self.is_sequence_column() and max_sequence_length < 1):
@@ -259,6 +272,10 @@ class _TPUBaseEmbeddingColumn(object):
       raise ValueError('Non zero max_seq_length={} specified for non '
                        'sequence column {}.'.format(max_sequence_length,
                                                     categorical_column.name))
+    self._partition_strategy = partition_strategy
+    if partition_strategy not in ('mod', 'div'):
+      raise ValueError('partition_strategy must be one of `mod` or `div`. '
+                       'Received {}.'.format(partition_strategy))
 
   def get_combiner(self):
     """Returns the embedding combiner."""
@@ -303,6 +320,9 @@ class _TPUBaseEmbeddingColumn(object):
     return get_sequence_length_feature_key_name_from_feature_key_name(
         self.get_feature_key_name())
 
+  def get_partition_strategy(self):
+    return self._partition_strategy
+
 
 class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
   """Core Embedding Column."""
@@ -316,7 +336,8 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
     # are not supported on TPU. They are solely for matching the signature of
     # __new__ of parent class fc._EmbeddingColumn.
@@ -340,9 +361,11 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
     _TPUBaseEmbeddingColumn.__init__(self, categorical_column,
-                                     max_sequence_length=max_sequence_length)
+                                     max_sequence_length=max_sequence_length,
+                                     partition_strategy=partition_strategy)
     self._key = None
 
   def get_combiner(self):
@@ -383,12 +406,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
 
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -405,12 +434,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -443,7 +478,8 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     return fc._SharedEmbeddingColumn.__new__(
         cls,
         categorical_column,
@@ -466,10 +502,12 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
 
     _TPUBaseEmbeddingColumn.__init__(self, categorical_column,
-                                     max_sequence_length=max_sequence_length)
+                                     max_sequence_length=max_sequence_length,
+                                     partition_strategy=partition_strategy)
     self._key = None
 
   def get_combiner(self):
@@ -510,12 +548,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
 
   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._SharedEmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -533,12 +577,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)
 
     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)
 
@@ -412,7 +412,8 @@ class TPUEmbedding(object):
                master,
                optimization_parameters=None,
                cluster_def=None,
-               pipeline_execution_with_tensor_core=False):
+               pipeline_execution_with_tensor_core=False,
+               partition_strategy='div'):
     """API for using TPU for embedding lookups.
 
     Args:
@@ -433,10 +434,18 @@ class TPUEmbedding(object):
         faster, but trained model will be different if step N and step N+1
        involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.
+      partition_strategy: A string, either 'mod' or 'div', specifying how to map
+        the lookup id to the embedding tensor. For more information see
+        `tf.nn.embedding_lookup_sparse`.
 
     Raises:
       ValueError: if any input is invalid.
     """
+    if partition_strategy not in ('div', 'mod'):
+      raise ValueError(
+          'Invalid partition_strategy {}'.format(partition_strategy))
+    self._partition_strategy = partition_strategy
+
     _validate_table_to_config_dict(table_to_config_dict)
     # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
     self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
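For the lower-level `TPUEmbedding` class, a rough constructor sketch; the table and feature configs, batch size, and master address are placeholder assumptions, and only the `partition_strategy` argument and its validation come from this diff:

from tensorflow.python.tpu import tpu_embedding  # assumed path

# Placeholder table/feature configuration for a single embedding table.
table = tpu_embedding.TableConfig(
    vocabulary_size=1000 * 1000, dimension=64, combiner='mean')
feature = tpu_embedding.FeatureConfig('video')  # feature backed by table 'video'

embedding = tpu_embedding.TPUEmbedding(
    table_to_config_dict={'video': table},
    feature_to_config_dict={'video_id': feature},
    batch_size=128,
    mode=tpu_embedding.TRAINING,
    master='grpc://tpu-worker:8470',  # placeholder address
    optimization_parameters=tpu_embedding.AdagradParameters(learning_rate=0.1),
    partition_strategy='mod')  # any other value raises ValueError, per above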
@@ -598,7 +607,10 @@ class TPUEmbedding(object):
     config_proto.batch_size_per_tensor_core = self._batch_size_per_core
     config_proto.num_hosts = self._num_hosts
     config_proto.num_tensor_cores = self._num_cores
-    config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+    config_proto.sharding_strategy = (
+        elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+        if self._partition_strategy == 'div' else
+        elc.TPUEmbeddingConfiguration.MOD)
     config_proto.pipeline_execution_with_tensor_core = (
         self._pipeline_execution_with_tensor_core)
 