Add a check in the TPUEstimator constructor for when the user is using TPU embedding but not the PER_HOST_V2 input pipeline.
PiperOrigin-RevId: 243929828
parent 7d2fd55ef6
commit 0500369bf5
@@ -2262,7 +2262,7 @@ class TPUEstimator(estimator_lib.Estimator):
                export_to_cpu=True,
                warm_start_from=None,
                experimental_export_device_assignment=False,
-               experimental_embedding_config_spec=None):
+               embedding_config_spec=None):
     """Constructs an `TPUEstimator` instance.
 
     Args:
@@ -2319,9 +2319,8 @@ class TPUEstimator(estimator_lib.Estimator):
         assignment in the exported model. Doing so is useful in case of model
         parallel inference but will tie the exported model to the TPU topology
         used to export the model.
-      experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance
-        to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE
-        DO NOT USE.
+      embedding_config_spec: Optional EmbeddingConfigSpec instance
+        to support using TPU embedding.
 
     Raises:
       ValueError: `params` has reserved keys already.
@@ -2355,6 +2354,13 @@ class TPUEstimator(estimator_lib.Estimator):
       util_lib.check_positive_integer(predict_batch_size,
                                       'predict_batch_size')
 
+    if embedding_config_spec:
+      if (config.tpu_config.per_host_input_for_training !=
+          tpu_config.InputPipelineConfig.PER_HOST_V2):
+        raise ValueError('Only PER_HOST_V2 is supported when using TPU '
+                         'Embedding; got {}.'.format(
+                             config.tpu_config.per_host_input_for_training))
+
     # Verifies the model_fn signature according to Estimator framework.
     estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
     # We cannot store config and params in this constructor as parent
@@ -2383,7 +2389,7 @@ class TPUEstimator(estimator_lib.Estimator):
     # pylint: disable=protected-access
     self._ctx = tpu_context._get_tpu_context(
         self._config, train_batch_size, eval_batch_size, predict_batch_size,
-        use_tpu, eval_on_tpu, experimental_embedding_config_spec)
+        use_tpu, eval_on_tpu, embedding_config_spec)
 
     self._export_to_cpu = export_to_cpu
     self._export_to_tpu = export_to_tpu
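For reference, a minimal sketch of how a caller satisfies the new check: the input pipeline must be configured as PER_HOST_V2 whenever embedding_config_spec is supplied. The module paths, the placeholder model_fn and embedding_config_spec, and the batch size below are assumptions based on the TF 1.x contrib TPU API of this era; only the PER_HOST_V2 requirement and the renamed embedding_config_spec argument come from this change.

# Sketch only: import paths, model_fn, and embedding_config_spec are
# caller-supplied assumptions, not part of this commit.
import tensorflow as tf

run_config = tf.contrib.tpu.RunConfig(
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=100,
        # The new constructor check raises ValueError unless this is
        # InputPipelineConfig.PER_HOST_V2 when embedding_config_spec is set.
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,                            # caller-supplied model_fn
    config=run_config,
    use_tpu=True,
    train_batch_size=1024,
    embedding_config_spec=embedding_config_spec)  # renamed from the old
                                                  # experimental_embedding_config_spec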