Address review comments

This commit is contained in:
Nathan Luehr 2020-06-08 11:21:49 -05:00
parent 03d836e261
commit 7791e36a57
3 changed files with 12 additions and 7 deletions

View File

@ -14,14 +14,16 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/platform/tf32_utils.h" #include "tensorflow/core/platform/tf32_utils.h"
#include <atomic>
namespace tensorflow { namespace tensorflow {
// TODO(nluehr): enable tf32 execution by default after TF32 Ampere testing. // Whether TensorFloat-32 should be used where supported.
static bool tf32_enabled = false; // TODO(nluehr): Maybe enable by default after TF32 Ampere testing.
static std::atomic<bool> tf32_allowed{false};
void allow_tf32_execution(bool allow) { tf32_enabled = allow; } void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; }
bool tf32_execution_allowed() { return tf32_enabled; } bool tf32_execution_allowed() { return tf32_allowed; }
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
void allow_tf32_execution(bool allow); void allow_tf32_execution(bool allowed);
bool tf32_execution_allowed(); bool tf32_execution_allowed();

View File

@ -23,6 +23,8 @@ from tensorflow.python.eager import context
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# No tf_export until TF is built against CUDA11 which is required for TF32.
def tensor_float32_execution_allowed(): def tensor_float32_execution_allowed():
"""Get if TensorFloat-32 operations are enabled on supported hardware. """Get if TensorFloat-32 operations are enabled on supported hardware.
@ -31,7 +33,8 @@ def tensor_float32_execution_allowed():
""" """
return _pywrap_tf32_execution.is_allowed() return _pywrap_tf32_execution.is_allowed()
def allow_tensor_float_32_execution(allow): # No tf_export until TF is built against CUDA11 which is required for TF32.
def allow_tensor_float_32_execution(allowed):
"""Allow use of TensorFloat-32 with float32 ops on supported hardware. """Allow use of TensorFloat-32 with float32 ops on supported hardware.
TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture. TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
@ -47,7 +50,7 @@ def allow_tensor_float_32_execution(allow):
Args: Args:
allow: whether to allow TensorFloat-32 execution allowed: whether to allow TensorFloat-32 execution
""" """
_pywrap_tf32_execution.allow(allow) _pywrap_tf32_execution.allow(allowed)
@tf_export('config.threading.get_intra_op_parallelism_threads') @tf_export('config.threading.get_intra_op_parallelism_threads')
def get_intra_op_parallelism_threads(): def get_intra_op_parallelism_threads():