diff --git a/tensorflow/python/tpu/tpu_estimator.py b/tensorflow/python/tpu/tpu_estimator.py index dcad6ebb2dc..a84be3c833f 100644 --- a/tensorflow/python/tpu/tpu_estimator.py +++ b/tensorflow/python/tpu/tpu_estimator.py @@ -2262,7 +2262,7 @@ class TPUEstimator(estimator_lib.Estimator): export_to_cpu=True, warm_start_from=None, experimental_export_device_assignment=False, - experimental_embedding_config_spec=None): + embedding_config_spec=None): """Constructs an `TPUEstimator` instance. Args: @@ -2319,9 +2319,8 @@ class TPUEstimator(estimator_lib.Estimator): assignment in the exported model. Doing so is useful in case of model parallel inference but will tie the exported model to the TPU topology used to export the model. - experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance - to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE - DO NOT USE. + embedding_config_spec: Optional EmbeddingConfigSpec instance + to support using TPU embedding. Raises: ValueError: `params` has reserved keys already. @@ -2355,6 +2354,13 @@ class TPUEstimator(estimator_lib.Estimator): util_lib.check_positive_integer(predict_batch_size, 'predict_batch_size') + if embedding_config_spec: + if (config.tpu_config.per_host_input_for_training != + tpu_config.InputPipelineConfig.PER_HOST_V2): + raise ValueError('Only PER_HOST_V2 is supported when using TPU ' + 'Embedding; got {}.'.format( + config.tpu_config.per_host_input_for_training)) + # Verifies the model_fn signature according to Estimator framework. estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access # We cannot store config and params in this constructor as parent @@ -2383,7 +2389,7 @@ class TPUEstimator(estimator_lib.Estimator): # pylint: disable=protected-access self._ctx = tpu_context._get_tpu_context( self._config, train_batch_size, eval_batch_size, predict_batch_size, - use_tpu, eval_on_tpu, experimental_embedding_config_spec) + use_tpu, eval_on_tpu, embedding_config_spec) self._export_to_cpu = export_to_cpu self._export_to_tpu = export_to_tpu