Export TPU Embedding related symbols to their respective name spaces (tf.tpu and tf.estimator.tpu).

PiperOrigin-RevId: 254099024
This commit is contained in:
Bruce Fontaine 2019-06-19 16:51:29 -07:00 committed by TensorFlower Gardener
parent 648145bfd7
commit a94695b4d1
13 changed files with 284 additions and 66 deletions

View File

@ -100,9 +100,7 @@ from tensorflow.python.ops.signal import signal
from tensorflow.python.profiler import profiler
from tensorflow.python.saved_model import saved_model
from tensorflow.python.summary import summary
from tensorflow.python.tpu import bfloat16 as _
from tensorflow.python.tpu import tpu as _
from tensorflow.python.tpu import tpu_optimizer as _
from tensorflow.python.tpu import api
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat

View File

@ -140,6 +140,7 @@ py_library(
name = "tpu_noestimator",
srcs = [
"__init__.py",
"api.py",
],
srcs_version = "PY2AND3",
deps = [

View File

@ -0,0 +1,31 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Modules that need to be exported to the API.
List TPU modules that aren't included elsewhere here so that they can be scanned
for tf_export decorations.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.tpu import bfloat16
from tensorflow.python.tpu import feature_column_v2
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu import tpu_optimizer
# pylint: enable=unused-import

View File

@ -28,32 +28,52 @@ from tensorflow.python.tpu.feature_column import _is_running_on_cpu
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@tf_export(v1=['tpu.experimental.embedding_column'])
def embedding_column_v2(categorical_column,
dimension,
combiner='mean',
initializer=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0):
"""TPU embedding_column for `tf.feature_column.embedding_column`.
"""TPU version of `tf.compat.v1.feature_column.embedding_column`.
Note that the interface for TPU embedding_column is different from the non-TPU
version. The following args available for the non-TPU version are NOT
supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable.
Note that the interface for `tf.tpu.experimental.embedding_column` is
different from that of `tf.compat.v1.feature_column.embedding_column`: The
following arguments are NOT supported: `ckpt_to_load_from`,
`tensor_name_in_ckpt`, `max_norm` and `trainable`.
Use this function in place of `tf.compat.v1.feature_column.embedding_column`
when you want to use the TPU to accelerate your embedding lookups via TPU
embeddings.
```
column = tf.feature_column.categorical_column_with_identity(...)
tpu_column = tf.tpu.experimental.embedding_column(column, 10)
...
def model_fn(features):
dense_feature = tf.keras.layers.DenseFeatures(tpu_column)
embedded_feature = dense_feature(features)
...
estimator = tf.estimator.tpu.TPUEstimator(
model_fn=model_fn,
...
embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
feature_columns=[tpu_column],
...))
```
Args:
categorical_column: A categorical_column returned from
categorical_column_with_identity, weighted_categorical_column,
categorical_column_with_vocabulary_file,
categorical_column_with_vocabulary_list,
sequence_categorical_column_with_identity,
sequence_categorical_column_with_vocabulary_file,
sequence_categorical_column_with_vocabulary_list
categorical_column: A categorical column returned from
`categorical_column_with_identity`, `weighted_categorical_column`,
`categorical_column_with_vocabulary_file`,
`categorical_column_with_vocabulary_list`,
`sequence_categorical_column_with_identity`,
`sequence_categorical_column_with_vocabulary_file`,
`sequence_categorical_column_with_vocabulary_list`
dimension: An integer specifying dimension of the embedding, must be > 0.
combiner: A string specifying how to reduce if there are multiple entries
in a single row for a non-sequence column. For more information, see
@ -62,31 +82,18 @@ def embedding_column_v2(categorical_column,
variable initialization. If not specified, defaults to
`tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
standard deviation `1/sqrt(dimension)`.
ckpt_to_load_from: Argument not used for TPU.
tensor_name_in_ckpt: Argument not used for TPU.
max_norm: Argument not used for TPU.
trainable: Argument not used for TPU.
max_sequence_length: A non-negative integer specifying the max sequence
length. Any sequence shorter than this will be padded with 0 embeddings
and any sequence longer will be truncated. This must be positive for
sequence features and 0 for non-sequence features.
Returns:
A _TPUEmbeddingColumnV2.
A `_TPUEmbeddingColumnV2`.
Raises:
ValueError: if `dimension` not > 0.
ValueError: if `initializer` is specified but not callable.
"""
if not (ckpt_to_load_from is None and tensor_name_in_ckpt is None):
raise ValueError('ckpt_to_load_from, tensor_name_in_ckpt are not '
'supported for TPU Embeddings. To load a embedding '
'table from a different checkpoint, use a scaffold_fn '
'and tf.train.init_from_checkpoint.')
if max_norm is not None:
raise ValueError('max_norm is not support for TPU Embeddings.')
if not trainable:
raise ValueError('TPU Embeddings do not support non-trainable weights.')
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
raise TypeError(
@ -114,30 +121,51 @@ def embedding_column_v2(categorical_column,
return column
@tf_export(v1=['tpu.experimental.shared_embedding_columns'])
def shared_embedding_columns_v2(categorical_columns,
dimension,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_lengths=None):
"""List of dense columns that convert from sparse, categorical input.
"""TPU version of `tf.compat.v1.feature_column.shared_embedding_columns`.
Note that the interface for TPU embedding_column is different from the non-TPU
version. The following args available for the non-TPU version are NOT
supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable.
Note that the interface for `tf.tpu.experimental.shared_embedding_columns` is
different from that of `tf.compat.v1.feature_column.shared_embedding_columns`:
The following arguments are NOT supported: `ckpt_to_load_from`,
`tensor_name_in_ckpt`, `max_norm` and `trainable`.
Use this function in place of
`tf.compat.v1.feature_column.shared_embedding_columns` when you want to use
the TPU to accelerate your embedding lookups via TPU embeddings.
```
column_a = tf.feature_column.categorical_column_with_identity(...)
column_b = tf.feature_column.categorical_column_with_identity(...)
tpu_columns = tf.tpu.experimental.shared_embedding_columns(
[column_a, column_b], 10)
...
def model_fn(features):
dense_feature = tf.keras.layers.DenseFeatures(tpu_columns)
embedded_feature = dense_feature(features)
...
estimator = tf.estimator.tpu.TPUEstimator(
model_fn=model_fn,
...
embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
feature_columns=tpu_columns,
...))
```
Args:
categorical_columns: A list of categorical_columns returned from
categorical_column_with_identity, weighted_categorical_column,
categorical_column_with_vocabulary_file,
categorical_column_with_vocabulary_list,
sequence_categorical_column_with_identity,
sequence_categorical_column_with_vocabulary_file,
sequence_categorical_column_with_vocabulary_list
categorical_columns: A list of categorical columns returned from
`categorical_column_with_identity`, `weighted_categorical_column`,
`categorical_column_with_vocabulary_file`,
`categorical_column_with_vocabulary_list`,
`sequence_categorical_column_with_identity`,
`sequence_categorical_column_with_vocabulary_file`,
`sequence_categorical_column_with_vocabulary_list`
dimension: An integer specifying dimension of the embedding, must be > 0.
combiner: A string specifying how to reduce if there are multiple entries
in a single row for a non-sequence column. For more information, see
@ -150,10 +178,6 @@ def shared_embedding_columns_v2(categorical_columns,
shared embedding weights are added. If not given, a reasonable name will
be chosen based on the names of `categorical_columns`. This is also used
in `variable_scope` when creating shared embedding weights.
ckpt_to_load_from: Argument not used for TPU.
tensor_name_in_ckpt: Argument not used for TPU.
max_norm: Argument not used for TPU.
trainable: Argument not used for TPU.
max_sequence_lengths: A list of non-negative integers, either None or
empty or the same length as the argument categorical_columns. Entries
corresponding to non-sequence columns must be 0 and entries corresponding
@ -162,7 +186,7 @@ def shared_embedding_columns_v2(categorical_columns,
sequence longer will be truncated.
Returns:
A _TPUSharedEmbeddingColumnV2.
A list of `_TPUSharedEmbeddingColumnV2`.
Raises:
ValueError: if `dimension` not > 0.
@ -172,15 +196,6 @@ def shared_embedding_columns_v2(categorical_columns,
ValueError: if `max_sequence_lengths` is positive for a non sequence column
or 0 for a sequence column.
"""
if not (ckpt_to_load_from is None and tensor_name_in_ckpt is None):
raise ValueError('ckpt_to_load_from, tensor_name_in_ckpt are not '
'supported for TPU Embeddings. To load a embedding '
'table from a different checkpoint, use a scaffold_fn '
'and tf.train.init_from_checkpoint.')
if max_norm is not None:
raise ValueError('max_norm is not support for TPU Embeddings.')
if not trainable:
raise ValueError('TPU Embeddings do not support non-trainable weights.')
for categorical_column in categorical_columns:
if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
@ -227,8 +242,9 @@ def shared_embedding_columns_v2(categorical_columns,
tpu_columns = []
column_creator = fc_lib.SharedEmbeddingColumnCreator(
dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
num_buckets, trainable, shared_embedding_collection_name)
dimension=dimension, initializer=initializer, ckpt_to_load_from=None,
tensor_name_in_ckpt=None, num_buckets=num_buckets, trainable=None,
name=shared_embedding_collection_name)
# Create the state (_SharedEmbeddingColumnLayer) here.
for categorical_column, max_sequence_length in zip(

View File

@ -37,6 +37,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util.tf_export import tf_export
TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
@ -228,8 +229,25 @@ class _OptimizationParameters(object):
self.clip_weight_max = clip_weight_max
@tf_export(v1=['tpu.experimental.AdagradParameters'])
class AdagradParameters(_OptimizationParameters):
"""Optimization parameters for Adagrad."""
"""Optimization parameters for Adagrad with TPU embeddings.
Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
`optimization_parameters` argument to set the optimizer and its parameters.
See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
for more details.
```
estimator = tf.estimator.tpu.TPUEstimator(
...
embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
...
optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
...))
```
"""
def __init__(self,
learning_rate,
@ -257,8 +275,25 @@ class AdagradParameters(_OptimizationParameters):
self.initial_accumulator = initial_accumulator
@tf_export(v1=['tpu.experimental.AdamParameters'])
class AdamParameters(_OptimizationParameters):
"""Optimization parameters for Adam."""
"""Optimization parameters for Adam with TPU embeddings.
Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
`optimization_parameters` argument to set the optimizer and its parameters.
See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
for more details.
```
estimator = tf.estimator.tpu.TPUEstimator(
...
embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
...
optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
...))
```
"""
def __init__(self,
learning_rate,
@ -310,8 +345,25 @@ class AdamParameters(_OptimizationParameters):
self.sum_inside_sqrt = sum_inside_sqrt
@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
class StochasticGradientDescentParameters(_OptimizationParameters):
"""Optimization parameters for stochastic gradient descent."""
"""Optimization parameters for stochastic gradient descent for TPU embeddings.
Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
`optimization_parameters` argument to set the optimizer and its parameters.
See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
for more details.
```
estimator = tf.estimator.tpu.TPUEstimator(
...
embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
...
optimization_parameters=(
tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
```
"""
def __init__(self, learning_rate, clip_weight_min=None,
clip_weight_max=None):

View File

@ -0,0 +1,47 @@
path: "tensorflow.estimator.tpu.experimental.EmbeddingConfigSpec"
tf_class {
is_instance: "<class \'tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding.EmbeddingConfigSpec\'>"
is_instance: "<class \'tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding.EmbeddingConfigSpec\'>"
is_instance: "<type \'tuple\'>"
member {
name: "clipping_limit"
mtype: "<type \'property\'>"
}
member {
name: "experimental_gradient_multiplier_fn"
mtype: "<type \'property\'>"
}
member {
name: "feature_columns"
mtype: "<type \'property\'>"
}
member {
name: "feature_to_config_dict"
mtype: "<type \'property\'>"
}
member {
name: "optimization_parameters"
mtype: "<type \'property\'>"
}
member {
name: "partition_strategy"
mtype: "<type \'property\'>"
}
member {
name: "pipeline_execution_with_tensor_core"
mtype: "<type \'property\'>"
}
member {
name: "table_to_config_dict"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.estimator.tpu.experimental"
tf_module {
member {
name: "EmbeddingConfigSpec"
mtype: "<type \'type\'>"
}
}

View File

@ -20,4 +20,8 @@ tf_module {
name: "TPUEstimatorSpec"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}

View File

@ -0,0 +1,10 @@
path: "tensorflow.tpu.experimental.AdagradParameters"
tf_class {
is_instance: "<class \'tensorflow.python.tpu.tpu_embedding.AdagradParameters\'>"
is_instance: "<class \'tensorflow.python.tpu.tpu_embedding._OptimizationParameters\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'0.1\', \'True\', \'None\', \'None\'], "
}
}

View File

@ -0,0 +1,10 @@
path: "tensorflow.tpu.experimental.AdamParameters"
tf_class {
is_instance: "<class \'tensorflow.python.tpu.tpu_embedding.AdamParameters\'>"
is_instance: "<class \'tensorflow.python.tpu.tpu_embedding._OptimizationParameters\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.999\', \'1e-08\', \'True\', \'True\', \'True\', \'None\', \'None\'], "
}
}

View File

@ -0,0 +1,10 @@
path: "tensorflow.tpu.experimental.StochasticGradientDescentParameters"
tf_class {
is_instance: "<class \'tensorflow.python.tpu.tpu_embedding.StochasticGradientDescentParameters\'>"
is_instance: "<class \'tensorflow.python.tpu.tpu_embedding._OptimizationParameters\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
}

View File

@ -1,11 +1,31 @@
path: "tensorflow.tpu.experimental"
tf_module {
member {
name: "AdagradParameters"
mtype: "<type \'type\'>"
}
member {
name: "AdamParameters"
mtype: "<type \'type\'>"
}
member {
name: "DeviceAssignment"
mtype: "<type \'type\'>"
}
member {
name: "StochasticGradientDescentParameters"
mtype: "<type \'type\'>"
}
member_method {
name: "embedding_column"
argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'max_sequence_length\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'0\'], "
}
member_method {
name: "initialize_tpu_system"
argspec: "args=[\'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "shared_embedding_columns"
argspec: "args=[\'categorical_columns\', \'dimension\', \'combiner\', \'initializer\', \'shared_embedding_collection_name\', \'max_sequence_lengths\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\'], "
}
}

View File

@ -359,6 +359,8 @@ renames = {
'tf.compat.v1.estimator.tpu.TPUEstimator',
'tf.estimator.tpu.TPUEstimatorSpec':
'tf.compat.v1.estimator.tpu.TPUEstimatorSpec',
'tf.estimator.tpu.experimental.EmbeddingSpec':
'tf.compat.v1.estimator.tpu.experimental.EmbeddingSpec',
'tf.expm1':
'tf.math.expm1',
'tf.fake_quant_with_min_max_args':
@ -1319,6 +1321,16 @@ renames = {
'tf.compat.v1.tpu.core',
'tf.tpu.cross_replica_sum':
'tf.compat.v1.tpu.cross_replica_sum',
'tf.tpu.experimental.AdagradParameters':
'tf.compat.v1.tpu.experimental.AdagradParameters',
'tf.tpu.experimental.AdamParameters':
'tf.compat.v1.tpu.experimental.AdamParameters',
'tf.tpu.experimental.StochasticGradientDescentParameters':
'tf.compat.v1.tpu.experimental.StochasticGradientDescentParameters',
'tf.tpu.experimental.embedding_column':
'tf.compat.v1.tpu.experimental.embedding_column',
'tf.tpu.experimental.shared_embedding_columns':
'tf.compat.v1.tpu.experimental.shared_embedding_columns',
'tf.tpu.initialize_system':
'tf.compat.v1.tpu.initialize_system',
'tf.tpu.outside_compilation':