diff --git a/tensorflow/core/platform/tf32_utils.cc b/tensorflow/core/platform/tf32_utils.cc index 715b5996dc3..4456e768c0a 100644 --- a/tensorflow/core/platform/tf32_utils.cc +++ b/tensorflow/core/platform/tf32_utils.cc @@ -14,14 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/tf32_utils.h" +#include <atomic> namespace tensorflow { -// TODO(nluehr): enable tf32 execution by default after TF32 Ampere testing. -static bool tf32_enabled = false; +// Whether TensorFloat-32 should be used where supported. +// TODO(nluehr): Maybe enable by default after TF32 Ampere testing. +static std::atomic<bool> tf32_allowed{false}; -void allow_tf32_execution(bool allow) { tf32_enabled = allow; } +void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; } -bool tf32_execution_allowed() { return tf32_enabled; } +bool tf32_execution_allowed() { return tf32_allowed; } } // namespace tensorflow diff --git a/tensorflow/core/platform/tf32_utils.h b/tensorflow/core/platform/tf32_utils.h index a0ce58f9bbd..7a158d00ad3 100644 --- a/tensorflow/core/platform/tf32_utils.h +++ b/tensorflow/core/platform/tf32_utils.h @@ -18,7 +18,7 @@ limitations under the License. namespace tensorflow { -void allow_tf32_execution(bool allow); +void allow_tf32_execution(bool allowed); bool tf32_execution_allowed(); diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index cb95965dfb2..e80ad1d72c4 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -23,6 +23,8 @@ from tensorflow.python.eager import context from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export + +# No tf_export until TF is built against CUDA11 which is required for TF32. def tensor_float32_execution_allowed(): """Get if TensorFloat-32 operations are enabled on supported hardware. 
@@ -31,7 +33,8 @@ def tensor_float32_execution_allowed(): """ return _pywrap_tf32_execution.is_allowed() -def allow_tensor_float_32_execution(allow): +# No tf_export until TF is built against CUDA11 which is required for TF32. +def allow_tensor_float_32_execution(allowed): """Allow use of TensorFloat-32 with float32 ops on supported hardware. TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture. @@ -47,7 +50,7 @@ def allow_tensor_float_32_execution(allow): Args: allow: whether to allow TensorFloat-32 execution """ - _pywrap_tf32_execution.allow(allow) + _pywrap_tf32_execution.allow(allowed) @tf_export('config.threading.get_intra_op_parallelism_threads') def get_intra_op_parallelism_threads():