Don't shut down TPU chips if compilation fails in TF2.
PiperOrigin-RevId: 276630160 Change-Id: I02c477e7fcd7a936821affa953a9801deaf7a6ec
This commit is contained in:
parent
bd422caa62
commit
830639c0e9
@ -205,6 +205,7 @@ REGISTER_OP("ConfigureDistributedTPU")
|
||||
.Attr("tpu_embedding_config: string = ''")
|
||||
.Attr("is_global_init: bool = false")
|
||||
.Attr("enable_whole_mesh_compilations: bool = false")
|
||||
.Attr("compilation_failure_closes_chips: bool = true")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
|
@ -92,7 +92,9 @@ def _tpu_system_device_name(job):
|
||||
|
||||
|
||||
@tf_export(v1=["tpu.initialize_system"])
|
||||
def initialize_system(embedding_config=None, job=None):
|
||||
def initialize_system(embedding_config=None,
|
||||
job=None,
|
||||
compilation_failure_closes_chips=True):
|
||||
"""Initializes a distributed TPU system for use with TensorFlow.
|
||||
|
||||
Args:
|
||||
@ -103,6 +105,8 @@ def initialize_system(embedding_config=None, job=None):
|
||||
contains the TPU devices that will be initialized. If job=None it is
|
||||
assumed there is only one job in the TensorFlow flock, and an error will
|
||||
be returned if this assumption does not hold.
|
||||
compilation_failure_closes_chips: Whether to close the TPU chips when
|
||||
a compilation failure occurs.
|
||||
Returns:
|
||||
A serialized `TopologyProto` that describes the TPU system. Note:
|
||||
the topology must be evaluated using `Session.run` before it can be used.
|
||||
@ -110,7 +114,9 @@ def initialize_system(embedding_config=None, job=None):
|
||||
config_string = ("" if embedding_config is None else
|
||||
embedding_config.SerializeToString())
|
||||
with ops.device(_tpu_system_device_name(job)):
|
||||
return tpu_ops.configure_distributed_tpu(embedding_config=config_string)
|
||||
return tpu_ops.configure_distributed_tpu(
|
||||
embedding_config=config_string,
|
||||
compilation_failure_closes_chips=compilation_failure_closes_chips)
|
||||
|
||||
|
||||
def initialize_system_for_tpu_embedding(embedding_config, job=None):
|
||||
|
@ -84,7 +84,11 @@ def initialize_tpu_system(cluster_resolver=None):
|
||||
|
||||
@function.defun
|
||||
def _tpu_init_fn():
|
||||
return tpu.initialize_system(job=job)
|
||||
# In TF1, we usually close chips on compilation failure in order to clear
|
||||
# the data in infeed. In TF2, we don't need to do this because infeed is no
|
||||
# longer used, so users can recover from TPU compilation failures more smoothly.
|
||||
return tpu.initialize_system(
|
||||
job=job, compilation_failure_closes_chips=False)
|
||||
|
||||
# The TPU_SYSTEM device must match the device used in tpu.initialize_system
|
||||
# exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
|
||||
|
@ -730,7 +730,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ConfigureDistributedTPU"
|
||||
argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'compilation_failure_closes_chips\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ConfigureTPUEmbedding"
|
||||
|
@ -26,7 +26,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "initialize_system"
|
||||
argspec: "args=[\'embedding_config\', \'job\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'embedding_config\', \'job\', \'compilation_failure_closes_chips\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "outside_compilation"
|
||||
|
@ -730,7 +730,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ConfigureDistributedTPU"
|
||||
argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'compilation_failure_closes_chips\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ConfigureTPUEmbedding"
|
||||
|
Loading…
Reference in New Issue
Block a user