diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index de9cf9a24c7..5f9e2dfb1ff 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -788,6 +788,16 @@ tf_python_pybind_extension( ], ) +tf_python_pybind_extension( + name = "_pywrap_tf32_execution", + srcs = ["util/tf32.cc"], + module_name = "_pywrap_tf32_execution", + deps = [ + "//tensorflow/core/platform:tf32_utils", + "@pybind11", + ], +) + tf_python_pybind_extension( name = "_pywrap_util_port", srcs = ["util/port_wrapper.cc"], @@ -5678,6 +5688,7 @@ py_library( "//tensorflow:composite_tensor_whitelist", ], deps = [ + ":_pywrap_tf32_execution", ":tf_decorator", ":tf_export", ":tf_stack", diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 9ff16f2a327..cb95965dfb2 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -18,10 +18,36 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import _pywrap_tf32_execution from tensorflow.python.eager import context from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export +def tensor_float32_execution_allowed(): + """Get if TensorFloat-32 operations are enabled on supported hardware. + + Returns: + True if TensorFloat-32 execution is enabled and False otherwise. + """ + return _pywrap_tf32_execution.is_allowed() + +def allow_tensor_float_32_execution(allow): + """Allow use of TensorFloat-32 with float32 ops on supported hardware. + + TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture. + TensorFloat-32 kernels take float32 inputs and produce float32 outputs. + Internally, the inputs are cast to a custom representation with 10-bit + mantissa (similar to float16) and 8-bit exponent (similar to float32) and are + executed using TensorCores with float32 accumulation. For more information, + see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/. + + TensorFloat-32 execution is disabled by default, but this may change in a + future version. + + Args: + allow: whether to allow TensorFloat-32 execution + """ + _pywrap_tf32_execution.allow(allow) @tf_export('config.threading.get_intra_op_parallelism_threads') def get_intra_op_parallelism_threads(): diff --git a/tensorflow/python/util/tf32.cc b/tensorflow/python/util/tf32.cc new file mode 100644 index 00000000000..7dece6ccdae --- /dev/null +++ b/tensorflow/python/util/tf32.cc @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "tensorflow/core/platform/tf32_utils.h" + +PYBIND11_MODULE(_pywrap_tf32_execution, m) { + m.def("allow", &tensorflow::allow_tf32_execution); + m.def("is_allowed", &tensorflow::tf32_execution_allowed); +}