Address review comments
This commit is contained in:
parent
03d836e261
commit
7791e36a57
@ -14,14 +14,16 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
#include <atomic>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// TODO(nluehr): enable tf32 execution by default after TF32 Ampere testing.
|
||||
static bool tf32_enabled = false;
|
||||
// Whether TensorFloat-32 should be used where supported.
|
||||
// TODO(nluehr): Maybe enable by default after TF32 Ampere testing.
|
||||
static std::atomic<bool> tf32_allowed{false};
|
||||
|
||||
void allow_tf32_execution(bool allow) { tf32_enabled = allow; }
|
||||
void allow_tf32_execution(bool allowed) { tf32_allowed = allowed; }
|
||||
|
||||
bool tf32_execution_allowed() { return tf32_enabled; }
|
||||
bool tf32_execution_allowed() { return tf32_allowed; }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void allow_tf32_execution(bool allow);
|
||||
void allow_tf32_execution(bool allowed);
|
||||
|
||||
bool tf32_execution_allowed();
|
||||
|
||||
|
@ -23,6 +23,8 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# No tf_export until TF is built against CUDA11 which is required for TF32.
|
||||
def tensor_float32_execution_allowed():
|
||||
"""Get if TensorFloat-32 operations are enabled on supported hardware.
|
||||
|
||||
@ -31,7 +33,8 @@ def tensor_float32_execution_allowed():
|
||||
"""
|
||||
return _pywrap_tf32_execution.is_allowed()
|
||||
|
||||
def allow_tensor_float_32_execution(allow):
|
||||
# No tf_export until TF is built against CUDA11 which is required for TF32.
|
||||
def allow_tensor_float_32_execution(allowed):
|
||||
"""Allow use of TensorFloat-32 with float32 ops on supported hardware.
|
||||
|
||||
TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
|
||||
@ -47,7 +50,7 @@ def allow_tensor_float_32_execution(allow):
|
||||
Args:
|
||||
allow: whether to allow TensorFloat-32 execution
|
||||
"""
|
||||
_pywrap_tf32_execution.allow(allow)
|
||||
_pywrap_tf32_execution.allow(allowed)
|
||||
|
||||
@tf_export('config.threading.get_intra_op_parallelism_threads')
|
||||
def get_intra_op_parallelism_threads():
|
||||
|
Loading…
Reference in New Issue
Block a user