Python tf.config tf32 interface

This commit is contained in:
Nathan Luehr 2020-05-15 13:33:02 -05:00
parent d2afc9ce83
commit 16033c0b34
3 changed files with 59 additions and 0 deletions

View File

@ -788,6 +788,16 @@ tf_python_pybind_extension(
],
)
tf_python_pybind_extension(
name = "_pywrap_tf32_execution",
srcs = ["util/tf32.cc"],
module_name = "_pywrap_tf32_execution",
deps = [
"//tensorflow/core/platform:tf32_utils",
"@pybind11",
],
)
tf_python_pybind_extension(
name = "_pywrap_util_port",
srcs = ["util/port_wrapper.cc"],
@ -5678,6 +5688,7 @@ py_library(
"//tensorflow:composite_tensor_whitelist",
],
deps = [
":_pywrap_tf32_execution",
":tf_decorator",
":tf_export",
":tf_stack",

View File

@ -18,10 +18,36 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import _pywrap_tf32_execution
from tensorflow.python.eager import context
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
def tensor_float32_execution_allowed():
"""Get if TensorFloat-32 operations are enabled on supported hardware.
Returns:
True if TensorFloat-32 execution is enabled and False otherwise.
"""
return _pywrap_tf32_execution.is_allowed()
def allow_tensor_float_32_execution(allow):
"""Allow use of TensorFloat-32 with float32 ops on supported hardware.
TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
TensorFloat-32 kernels take float32 inputs and produce float32 outputs.
Internally, the inputs are cast to a custom representation with 10-bit
mantissa (similar to float16) and 8-bit exponent (similar to float32) and are
executed using TensorCores with float32 accumulation. For more information,
see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/.
TensorFloat-32 execution is disabled by default, but this may change in a
future version.
Args:
allow: whether to allow TensorFloat-32 execution
"""
_pywrap_tf32_execution.allow(allow)
@tf_export('config.threading.get_intra_op_parallelism_threads')
def get_intra_op_parallelism_threads():

View File

@ -0,0 +1,22 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "tensorflow/core/platform/tf32_utils.h"
PYBIND11_MODULE(_pywrap_tf32_execution, m) {
m.def("allow", &tensorflow::allow_tf32_execution);
m.def("is_allowed", &tensorflow::tf32_execution_allowed);
}