Adds a reminder to set export_saved_model_api_version=ExportSavedModelApiVersion.V2 when using TPUEstimator's inference_on_tpu.

PiperOrigin-RevId: 345329997
Change-Id: I00b859503a09ce1bab50afc223318afb05d563bf
This commit is contained in:
A. Unique TensorFlower 2020-12-02 15:51:45 -08:00 committed by TensorFlower Gardener
parent effe8fd10c
commit 254c5b9da2

View File

@ -45,8 +45,14 @@ _current_tpu_context = TpuContext()
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
"""A context manager setting current number of shards."""
if _current_tpu_context.number_of_shards is not None:
raise NotImplementedError("tpu_shard_context cannot be nested.")
raise NotImplementedError(
"tpu_shard_context cannot be nested."
"If you're using TPUEstimator with inference_on_tpu, "
"make sure you have set "
"export_saved_model_api_version=ExportSavedModelApiVersion.V2 in "
"the creation of TPUEstimator.")
try:
_current_tpu_context.set_number_of_shards(number_of_shards)
yield