diff --git a/tensorflow/python/tpu/feature_column.py b/tensorflow/python/tpu/feature_column.py
index 5f9535dc3c6..6039a57ce90 100644
--- a/tensorflow/python/tpu/feature_column.py
+++ b/tensorflow/python/tpu/feature_column.py
@@ -86,7 +86,10 @@ def embedding_column(categorical_column,
       and any sequence longer will be truncated. This must be positive for
       sequence features and 0 for non-sequence features.
     learning_rate_fn: A function that takes global step and returns learning
-      rate for the embedding table.
+      rate for the embedding table. If you intend to use the same learning rate
+      for multiple embedding tables, please ensure that you pass the exact same
+      python function to all calls of embedding_column, otherwise performence
+      may suffer.
     use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
       instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
       there are no empty rows and all weights and ids are positive at the
@@ -196,7 +199,10 @@ def shared_embedding_columns(categorical_columns,
       sequence shorter then this will be padded with 0 embeddings and any
       sequence longer will be truncated.
     learning_rate_fn: A function that takes global step and returns learning
-      rate for the embedding table.
+      rate for the embedding table. If you intend to use the same learning rate
+      for multiple embedding tables, please ensure that you pass the exact same
+      python function to all calls of shared_embedding_columns, otherwise
+      performence may suffer.
     use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
       instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
       there are no empty rows and all weights and ids are positive at the
diff --git a/tensorflow/python/tpu/feature_column_v2.py b/tensorflow/python/tpu/feature_column_v2.py
index a51a3153c76..d9820425467 100644
--- a/tensorflow/python/tpu/feature_column_v2.py
+++ b/tensorflow/python/tpu/feature_column_v2.py
@@ -107,7 +107,10 @@ def embedding_column_v2(categorical_column,
       and any sequence longer will be truncated. This must be positive for
       sequence features and 0 for non-sequence features.
     learning_rate_fn: A function that takes global step and returns learning
-      rate for the embedding table.
+      rate for the embedding table. If you intend to use the same learning rate
+      for multiple embedding tables, please ensure that you pass the exact same
+      python function to all calls of embedding_column, otherwise performence
+      may suffer.
     embedding_lookup_device: The device on which to run the embedding lookup.
       Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core".
       If specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
@@ -266,7 +269,10 @@ def shared_embedding_columns_v2(categorical_columns,
       sequence shorter then this will be padded with 0 embeddings and any
       sequence longer will be truncated.
     learning_rate_fn: A function that takes global step and returns learning
-      rate for the embedding table.
+      rate for the embedding table. If you intend to use the same learning rate
+      for multiple embedding tables, please ensure that you pass the exact same
+      python function to all calls of shared_embedding_columns, otherwise
+      performence may suffer.
     embedding_lookup_device: The device on which to run the embedding lookup.
       Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If
       specifying "tpu_tensor_core", a tensor_core_shape must be supplied.