diff --git a/tensorflow/core/ops/tpu_configuration_ops.cc b/tensorflow/core/ops/tpu_configuration_ops.cc
index 87607bd37df..ae6e2b9b1b5 100644
--- a/tensorflow/core/ops/tpu_configuration_ops.cc
+++ b/tensorflow/core/ops/tpu_configuration_ops.cc
@@ -205,6 +205,7 @@ REGISTER_OP("ConfigureDistributedTPU")
     .Attr("tpu_embedding_config: string = ''")
     .Attr("is_global_init: bool = false")
     .Attr("enable_whole_mesh_compilations: bool = false")
+    .Attr("compilation_failure_closes_chips: bool = true")
     .SetIsStateful()
     .SetShapeFn(shape_inference::UnknownShape);
 
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index 08ab0465c96..9aa1fc5ef5b 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -92,7 +92,9 @@ def _tpu_system_device_name(job):
 
 
 @tf_export(v1=["tpu.initialize_system"])
-def initialize_system(embedding_config=None, job=None):
+def initialize_system(embedding_config=None,
+                      job=None,
+                      compilation_failure_closes_chips=True):
   """Initializes a distributed TPU system for use with TensorFlow.
 
   Args:
@@ -103,6 +105,8 @@
       contains the TPU devices that will be initialized. If job=None it is
       assumed there is only one job in the TensorFlow flock, and an error
       will be returned if this assumption does not hold.
+    compilation_failure_closes_chips: Whether to close TPU chips when a
+      compilation failure occurs.
   Returns:
     A serialized `TopologyProto` that describes the TPU system. Note:
       the topology must be evaluated using `Session.run` before it can be used.
@@ -110,7 +114,9 @@
   config_string = ("" if embedding_config is None
                    else embedding_config.SerializeToString())
   with ops.device(_tpu_system_device_name(job)):
-    return tpu_ops.configure_distributed_tpu(embedding_config=config_string)
+    return tpu_ops.configure_distributed_tpu(
+        embedding_config=config_string,
+        compilation_failure_closes_chips=compilation_failure_closes_chips)
 
 
 def initialize_system_for_tpu_embedding(embedding_config, job=None):
diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index dd5585c49b9..f3f1e78251f 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -84,7 +84,11 @@ def initialize_tpu_system(cluster_resolver=None):
 
   @function.defun
   def _tpu_init_fn():
-    return tpu.initialize_system(job=job)
+    # In TF1 we usually close TPU chips when compilation fails, to clear any
+    # data left in infeed. In TF2 infeed is no longer used, so we skip that
+    # step and users can recover from compilation failures more smoothly.
+    return tpu.initialize_system(
+        job=job, compilation_failure_closes_chips=False)
 
   # The TPU_SYSTEM device must match the device used in tpu.initialize_system
   # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index a177114571d..89a81e9148e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -730,7 +730,7 @@ tf_module {
   }
   member_method {
     name: "ConfigureDistributedTPU"
-    argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'None\'], "
+    argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'compilation_failure_closes_chips\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'True\', \'None\'], "
   }
   member_method {
     name: "ConfigureTPUEmbedding"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.pbtxt
index 9890d3bc929..f7a1a1a772c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.pbtxt
@@ -26,7 +26,7 @@ tf_module {
   }
   member_method {
     name: "initialize_system"
-    argspec: "args=[\'embedding_config\', \'job\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'embedding_config\', \'job\', \'compilation_failure_closes_chips\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
  }
  member_method {
    name: "outside_compilation"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index a177114571d..89a81e9148e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -730,7 +730,7 @@ tf_module {
   }
   member_method {
     name: "ConfigureDistributedTPU"
-    argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'None\'], "
+    argspec: "args=[\'embedding_config\', \'tpu_embedding_config\', \'is_global_init\', \'enable_whole_mesh_compilations\', \'compilation_failure_closes_chips\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'False\', \'False\', \'True\', \'None\'], "
  }
  member_method {
    name: "ConfigureTPUEmbedding"
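For context, a minimal usage sketch of the new keyword from TF1 graph mode, where the default stays True and callers opt out explicitly. This is illustrative and not part of the change itself: the TPU worker address and the final print are placeholders; only `compilation_failure_closes_chips` and the Session.run requirement come from the diff above.

    # Illustrative sketch only: the endpoint below is a hypothetical placeholder.
    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()

    TPU_WORKER = "grpc://10.240.1.2:8470"  # hypothetical TPU worker address

    with tf.Session(TPU_WORKER) as sess:
      # Override the TF1 default (True) to keep chips open after a compilation
      # failure, mirroring what initialize_tpu_system now does in TF2.
      topology = sess.run(
          tf.tpu.initialize_system(compilation_failure_closes_chips=False))
      print("TPU topology proto: %d bytes" % len(topology))

Note the asymmetry the change introduces: the op attribute and the TF1 API default to True (the old chip-closing behavior is preserved for graph-mode users relying on infeed), while TF2's initialize_tpu_system hardcodes False because infeed is not used there.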