Add check when user is using TPU embedding but not PER_HOST_V2 pipeline in TPUEstimator constructor.

PiperOrigin-RevId: 243929828
This commit is contained in:
A. Unique TensorFlower 2019-04-16 20:47:43 -07:00 committed by TensorFlower Gardener
parent 7d2fd55ef6
commit 0500369bf5

View File

@ -2262,7 +2262,7 @@ class TPUEstimator(estimator_lib.Estimator):
export_to_cpu=True,
warm_start_from=None,
experimental_export_device_assignment=False,
experimental_embedding_config_spec=None):
embedding_config_spec=None):
"""Constructs an `TPUEstimator` instance.
Args:
@ -2319,9 +2319,8 @@ class TPUEstimator(estimator_lib.Estimator):
assignment in the exported model. Doing so is useful in case of model
parallel inference but will tie the exported model to the TPU topology
used to export the model.
experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance
to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE
DO NOT USE.
embedding_config_spec: Optional EmbeddingConfigSpec instance
to support using TPU embedding.
Raises:
ValueError: `params` has reserved keys already.
@ -2355,6 +2354,13 @@ class TPUEstimator(estimator_lib.Estimator):
util_lib.check_positive_integer(predict_batch_size,
'predict_batch_size')
if embedding_config_spec:
if (config.tpu_config.per_host_input_for_training !=
tpu_config.InputPipelineConfig.PER_HOST_V2):
raise ValueError('Only PER_HOST_V2 is supported when using TPU '
'Embedding; got {}.'.format(
config.tpu_config.per_host_input_for_training))
# Verifies the model_fn signature according to Estimator framework.
estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access
# We cannot store config and params in this constructor as parent
@ -2383,7 +2389,7 @@ class TPUEstimator(estimator_lib.Estimator):
# pylint: disable=protected-access
self._ctx = tpu_context._get_tpu_context(
self._config, train_batch_size, eval_batch_size, predict_batch_size,
use_tpu, eval_on_tpu, experimental_embedding_config_spec)
use_tpu, eval_on_tpu, embedding_config_spec)
self._export_to_cpu = export_to_cpu
self._export_to_tpu = export_to_tpu