Adds a reminder to set export_saved_model_api_version=ExportSavedModelApiVersion.V2 when using TPUEstimator's inference_on_tpu.
PiperOrigin-RevId: 345329997 Change-Id: I00b859503a09ce1bab50afc223318afb05d563bf
This commit is contained in:
parent
effe8fd10c
commit
254c5b9da2
@ -45,8 +45,14 @@ _current_tpu_context = TpuContext()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tpu_shard_context(number_of_shards):
|
||||
"""A context manager setting current number of shards."""
|
||||
if _current_tpu_context.number_of_shards is not None:
|
||||
raise NotImplementedError("tpu_shard_context cannot be nested.")
|
||||
raise NotImplementedError(
|
||||
"tpu_shard_context cannot be nested."
|
||||
"If you're using TPUEstimator with inference_on_tpu, "
|
||||
"make sure you have set "
|
||||
"export_saved_model_api_version=ExportSavedModelApiVersion.V2 in "
|
||||
"the creation of TPUEstimator.")
|
||||
try:
|
||||
_current_tpu_context.set_number_of_shards(number_of_shards)
|
||||
yield
|
||||
|
Loading…
Reference in New Issue
Block a user