diff --git a/tensorflow/python/tpu/feature_column.py b/tensorflow/python/tpu/feature_column.py
index 39a1307a83c..57eb9dda3cc 100644
--- a/tensorflow/python/tpu/feature_column.py
+++ b/tensorflow/python/tpu/feature_column.py
@@ -50,7 +50,8 @@ def embedding_column(categorical_column,
                      dimension,
                      combiner='mean',
                      initializer=None,
-                     max_sequence_length=0):
+                     max_sequence_length=0,
+                     partition_strategy='div'):
   """TPU embedding_column for `tf.feature_column.embedding_column`.

   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -77,6 +78,11 @@ def embedding_column(categorical_column,
      length. Any sequence shorter then this will be padded with 0 embeddings
      and any sequence longer will be truncated. This must be positive for
      sequence features and 0 for non-sequence features.
+    partition_strategy: Determines how the embedding tensors are sharded
+      across the TPU hosts. See `tf.nn.safe_embedding_lookup_sparse` for
+      details. Allowed values are `"div"` and `"mod"`. If `"mod"` is used,
+      evaluation and exporting the model to CPU will not work; the embedding
+      tensors must first be reshuffled into a single shard.

   Returns:
     A _TPUEmbeddingColumn.
@@ -122,7 +128,8 @@ def embedding_column(categorical_column,
       tensor_name_in_ckpt=None,
       max_norm=None,
       trainable=True,
-      max_sequence_length=max_sequence_length)
+      max_sequence_length=max_sequence_length,
+      partition_strategy=partition_strategy)
   # For Embedding column, the initializer is hidden inside the creator Fn, which
   # is not accessiable later. So, we attach it to a speicial field. Also note
   # that non-TPU Embedding column and non-TPU shared Embedding column handle the
@@ -136,7 +143,8 @@ def shared_embedding_columns(categorical_columns,
                              combiner='mean',
                              initializer=None,
                              shared_embedding_collection_name=None,
-                             max_sequence_lengths=None):
+                             max_sequence_lengths=None,
+                             partition_strategy='div'):
   """List of dense columns that convert from sparse, categorical input.

   Note that the interface for TPU embedding_column is different from the non-TPU
@@ -169,6 +177,9 @@ def shared_embedding_columns(categorical_columns,
      to sequence columns specify the max sequence length for the column. Any
      sequence shorter then this will be padded with 0 embeddings and any
      sequence longer will be truncated.
+    partition_strategy: Determines how the embedding tensors are sharded
+      across the TPU hosts. See `tf.nn.safe_embedding_lookup_sparse` for
+      details. Allowed values are `"div"` and `"mod"`.

   Returns:
     A _TPUEmbeddingColumn.
@@ -238,7 +249,8 @@ def shared_embedding_columns(categorical_columns,
         tensor_name_in_ckpt=None,
         max_norm=None,
         trainable=True,
-        max_sequence_length=max_sequence_length)
+        max_sequence_length=max_sequence_length,
+        partition_strategy=partition_strategy)
     tpu_columns.append(column)

   return tpu_columns
@@ -247,7 +259,8 @@ def shared_embedding_columns(categorical_columns,
 class _TPUBaseEmbeddingColumn(object):
   """Base class for TPU Embedding Column."""

-  def __init__(self, categorical_column, max_sequence_length=0):
+  def __init__(self, categorical_column, max_sequence_length=0,
+               partition_strategy='div'):
     self._tpu_categorical_column = categorical_column
     self._max_sequence_length = max_sequence_length
     if (self.is_sequence_column() and max_sequence_length < 1):
@@ -259,6 +272,10 @@ class _TPUBaseEmbeddingColumn(object):
       raise ValueError('Non zero max_seq_length={} specified for non '
                        'sequence column {}.'.format(max_sequence_length,
                                                     categorical_column.name))
+    self._partition_strategy = partition_strategy
+    if partition_strategy not in ('mod', 'div'):
+      raise ValueError('partition_strategy must be one of `mod` or `div`. '
+                       'Received {}.'.format(partition_strategy))

   def get_combiner(self):
     """Returns the embedding combiner."""
@@ -303,6 +320,9 @@ class _TPUBaseEmbeddingColumn(object):
     return get_sequence_length_feature_key_name_from_feature_key_name(
         self.get_feature_key_name())

+  def get_partition_strategy(self):
+    return self._partition_strategy
+

 class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
   """Core Embedding Column."""
@@ -316,7 +336,8 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
     # are not supported on TPU. They are solely for matching the signature of
     # __new__ of parent class fc._EmbeddingColumn.
@@ -340,9 +361,11 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
     _TPUBaseEmbeddingColumn.__init__(self, categorical_column,
-                                     max_sequence_length=max_sequence_length)
+                                     max_sequence_length=max_sequence_length,
+                                     partition_strategy=partition_strategy)
     self._key = None

   def get_combiner(self):
@@ -383,12 +406,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):

   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)

     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)

@@ -405,12 +434,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._EmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)

     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._EmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)

@@ -443,7 +478,8 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
-              max_sequence_length=0):
+              max_sequence_length=0,
+              partition_strategy='div'):
     return fc._SharedEmbeddingColumn.__new__(
         cls,
         categorical_column,
@@ -466,10 +502,12 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                tensor_name_in_ckpt=None,
                max_norm=None,
                trainable=True,
-               max_sequence_length=0):
+               max_sequence_length=0,
+               partition_strategy='div'):
     _TPUBaseEmbeddingColumn.__init__(
         self, categorical_column,
-        max_sequence_length=max_sequence_length)
+        max_sequence_length=max_sequence_length,
+        partition_strategy=partition_strategy)
     self._key = None

   def get_combiner(self):
@@ -510,12 +548,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,

   def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._SharedEmbeddingColumn._get_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)

     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_dense_tensor(
           self, inputs, weight_collections, trainable)

@@ -533,12 +577,18 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
   def _get_sequence_dense_tensor(
       self, inputs, weight_collections=None, trainable=None):
     if tpu.under_tpu_inference_context():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('Export saved model does not support MOD '
+                                  'sharded embeddings.')
       def host_computation():
         return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
             self, inputs, weight_collections, trainable)
       return tpu.outside_compilation(host_computation)

     if _is_running_on_cpu():
+      if self._partition_strategy == 'mod':
+        raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
+                                  'sharded embeddings.')
       return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
           self, inputs, weight_collections, trainable)

diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index 1d3b7cf9256..81d3da1bc68 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -412,7 +412,8 @@ class TPUEmbedding(object):
                master,
                optimization_parameters=None,
                cluster_def=None,
-               pipeline_execution_with_tensor_core=False):
+               pipeline_execution_with_tensor_core=False,
+               partition_strategy='div'):
     """API for using TPU for embedding lookups.

     Args:
@@ -433,10 +434,18 @@ class TPUEmbedding(object):
         faster, but trained model will be different if step N and step N+1
         involve the same set of embedding IDs. Please see
         `tpu_embedding_configuration.proto` for details.
+      partition_strategy: A string, either 'mod' or 'div', specifying how to
+        map the lookup id to the embedding tensor. For more information see
+        `tf.nn.embedding_lookup_sparse`.

     Raises:
       ValueError: if any input is invalid.
     """
+    if partition_strategy not in ('div', 'mod'):
+      raise ValueError(
+          'Invalid partition_strategy {}'.format(partition_strategy))
+    self._partition_strategy = partition_strategy
+
     _validate_table_to_config_dict(table_to_config_dict)
     # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
     self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
@@ -598,7 +607,10 @@ class TPUEmbedding(object):
     config_proto.batch_size_per_tensor_core = self._batch_size_per_core
     config_proto.num_hosts = self._num_hosts
     config_proto.num_tensor_cores = self._num_cores
-    config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+    config_proto.sharding_strategy = (
+        elc.TPUEmbeddingConfiguration.DIV_DEFAULT
+        if self._partition_strategy == 'div' else
+        elc.TPUEmbeddingConfiguration.MOD)
     config_proto.pipeline_execution_with_tensor_core = (
         self._pipeline_execution_with_tensor_core)
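
Usage sketch (not part of the patch): the snippet below illustrates how the new `partition_strategy` argument would be passed through the TPU feature-column API changed above. It imports the module path touched in this diff (`tensorflow.python.tpu.feature_column`) and uses a hypothetical `video_id` categorical column purely for illustration; the public alias for this function (e.g. under `tf.tpu.experimental`) may differ between releases.

import tensorflow as tf
from tensorflow.python.tpu import feature_column as tpu_fc

# Hypothetical categorical feature, used only for illustration.
video_id = tf.feature_column.categorical_column_with_identity(
    key='video_id', num_buckets=1000000)

# 'div' (the default) keeps CPU evaluation and SavedModel export working;
# per the docstring added above, 'mod' shards ids across the TPU hosts and
# cannot be evaluated or exported on CPU without reshuffling the embedding
# tensors into a single shard.
video_embedding = tpu_fc.embedding_column(
    video_id, dimension=64, combiner='mean', partition_strategy='div')

# Any value other than 'div' or 'mod' raises ValueError in
# _TPUBaseEmbeddingColumn.__init__. The same flag on the low-level
# TPUEmbedding constructor selects DIV_DEFAULT vs. MOD for
# TPUEmbeddingConfiguration.sharding_strategy.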