Python tf.config tf32 interface
This commit is contained in:
parent
d2afc9ce83
commit
16033c0b34
|
@ -788,6 +788,16 @@ tf_python_pybind_extension(
|
|||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_tf32_execution",
|
||||
srcs = ["util/tf32.cc"],
|
||||
module_name = "_pywrap_tf32_execution",
|
||||
deps = [
|
||||
"//tensorflow/core/platform:tf32_utils",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_util_port",
|
||||
srcs = ["util/port_wrapper.cc"],
|
||||
|
@ -5678,6 +5688,7 @@ py_library(
|
|||
"//tensorflow:composite_tensor_whitelist",
|
||||
],
|
||||
deps = [
|
||||
":_pywrap_tf32_execution",
|
||||
":tf_decorator",
|
||||
":tf_export",
|
||||
":tf_stack",
|
||||
|
|
|
@ -18,10 +18,36 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import _pywrap_tf32_execution
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
def tensor_float32_execution_allowed():
|
||||
"""Get if TensorFloat-32 operations are enabled on supported hardware.
|
||||
|
||||
Returns:
|
||||
True if TensorFloat-32 execution is enabled and False otherwise.
|
||||
"""
|
||||
return _pywrap_tf32_execution.is_allowed()
|
||||
|
||||
def allow_tensor_float_32_execution(allow):
|
||||
"""Allow use of TensorFloat-32 with float32 ops on supported hardware.
|
||||
|
||||
TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
|
||||
TensorFloat-32 kernels take float32 inputs and produce float32 outputs.
|
||||
Internally, the inputs are cast to a custom representation with 10-bit
|
||||
mantissa (similar to float16) and 8-bit exponent (similar to float32) and are
|
||||
executed using TensorCores with float32 accumulation. For more information,
|
||||
see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/.
|
||||
|
||||
TensorFloat-32 execution is disabled by default, but this may change in a
|
||||
future version.
|
||||
|
||||
Args:
|
||||
allow: whether to allow TensorFloat-32 execution
|
||||
"""
|
||||
_pywrap_tf32_execution.allow(allow)
|
||||
|
||||
@tf_export('config.threading.get_intra_op_parallelism_threads')
|
||||
def get_intra_op_parallelism_threads():
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/core/platform/tf32_utils.h"
|
||||
|
||||
PYBIND11_MODULE(_pywrap_tf32_execution, m) {
|
||||
m.def("allow", &tensorflow::allow_tf32_execution);
|
||||
m.def("is_allowed", &tensorflow::tf32_execution_allowed);
|
||||
}
|
Loading…
Reference in New Issue