Address review comments

This commit is contained in:
Nathan Luehr 2020-06-08 11:21:49 -05:00
parent 03d836e261
commit 7791e36a57
3 changed files with 12 additions and 7 deletions

View File

@ -14,14 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/tf32_utils.h"
#include <atomic>
namespace tensorflow {
// TODO(nluehr): enable tf32 execution by default after TF32 Ampere testing.
static bool tf32_enabled = false;
// Whether TensorFloat-32 should be used where supported.
// TODO(nluehr): Maybe enable by default after TF32 Ampere testing.
static std::atomic<bool> tf32_allowed{false};
void allow_tf32_execution(bool allow) { tf32_enabled = allow; }
void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; }
bool tf32_execution_allowed() { return tf32_enabled; }
bool tf32_execution_allowed() { return tf32_allowed; }
} // namespace tensorflow

View File

@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow {
void allow_tf32_execution(bool allow);
void allow_tf32_execution(bool allowed);
bool tf32_execution_allowed();

View File

@ -23,6 +23,8 @@ from tensorflow.python.eager import context
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# No tf_export until TF is built against CUDA11 which is required for TF32.
def tensor_float32_execution_allowed():
"""Get if TensorFloat-32 operations are enabled on supported hardware.
@ -31,7 +33,8 @@ def tensor_float32_execution_allowed():
"""
return _pywrap_tf32_execution.is_allowed()
def allow_tensor_float_32_execution(allow):
# No tf_export until TF is built against CUDA11 which is required for TF32.
def allow_tensor_float_32_execution(allowed):
"""Allow use of TensorFloat-32 with float32 ops on supported hardware.
TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
@ -47,7 +50,7 @@ def allow_tensor_float_32_execution(allow):
Args:
allow: whether to allow TensorFloat-32 execution
"""
_pywrap_tf32_execution.allow(allow)
_pywrap_tf32_execution.allow(allowed)
@tf_export('config.threading.get_intra_op_parallelism_threads')
def get_intra_op_parallelism_threads():