diff --git a/tensorflow/python/tpu/feature_column.py b/tensorflow/python/tpu/feature_column.py index 5f9535dc3c6..6039a57ce90 100644 --- a/tensorflow/python/tpu/feature_column.py +++ b/tensorflow/python/tpu/feature_column.py @@ -86,7 +86,10 @@ def embedding_column(categorical_column, and any sequence longer will be truncated. This must be positive for sequence features and 0 for non-sequence features. learning_rate_fn: A function that takes global step and returns learning - rate for the embedding table. + rate for the embedding table. If you intend to use the same learning rate + for multiple embedding tables, please ensure that you pass the exact same + python function to all calls of embedding_column, otherwise performence + may suffer. use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures there are no empty rows and all weights and ids are positive at the @@ -196,7 +199,10 @@ def shared_embedding_columns(categorical_columns, sequence shorter then this will be padded with 0 embeddings and any sequence longer will be truncated. learning_rate_fn: A function that takes global step and returns learning - rate for the embedding table. + rate for the embedding table. If you intend to use the same learning rate + for multiple embedding tables, please ensure that you pass the exact same + python function to all calls of shared_embedding_columns, otherwise + performence may suffer. use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures there are no empty rows and all weights and ids are positive at the diff --git a/tensorflow/python/tpu/feature_column_v2.py b/tensorflow/python/tpu/feature_column_v2.py index a51a3153c76..d9820425467 100644 --- a/tensorflow/python/tpu/feature_column_v2.py +++ b/tensorflow/python/tpu/feature_column_v2.py @@ -107,7 +107,10 @@ def embedding_column_v2(categorical_column, and any sequence longer will be truncated. This must be positive for sequence features and 0 for non-sequence features. learning_rate_fn: A function that takes global step and returns learning - rate for the embedding table. + rate for the embedding table. If you intend to use the same learning rate + for multiple embedding tables, please ensure that you pass the exact same + python function to all calls of embedding_column, otherwise performence + may suffer. embedding_lookup_device: The device on which to run the embedding lookup. Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If specifying "tpu_tensor_core", a tensor_core_shape must be supplied. @@ -266,7 +269,10 @@ def shared_embedding_columns_v2(categorical_columns, sequence shorter then this will be padded with 0 embeddings and any sequence longer will be truncated. learning_rate_fn: A function that takes global step and returns learning - rate for the embedding table. + rate for the embedding table. If you intend to use the same learning rate + for multiple embedding tables, please ensure that you pass the exact same + python function to all calls of shared_embedding_columns, otherwise + performence may suffer. embedding_lookup_device: The device on which to run the embedding lookup. Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If specifying "tpu_tensor_core", a tensor_core_shape must be supplied.