Python tf.config tf32 interface
This commit is contained in:
parent
d2afc9ce83
commit
16033c0b34
|
@ -788,6 +788,16 @@ tf_python_pybind_extension(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Pybind11 extension module `_pywrap_tf32_execution`, built from util/tf32.cc
|
||||||
|
# on top of //tensorflow/core/platform:tf32_utils.
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_tf32_execution",
|
||||||
|
srcs = ["util/tf32.cc"],
|
||||||
|
module_name = "_pywrap_tf32_execution",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core/platform:tf32_utils",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_python_pybind_extension(
|
tf_python_pybind_extension(
|
||||||
name = "_pywrap_util_port",
|
name = "_pywrap_util_port",
|
||||||
srcs = ["util/port_wrapper.cc"],
|
srcs = ["util/port_wrapper.cc"],
|
||||||
|
@ -5678,6 +5688,7 @@ py_library(
|
||||||
"//tensorflow:composite_tensor_whitelist",
|
"//tensorflow:composite_tensor_whitelist",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":_pywrap_tf32_execution",
|
||||||
":tf_decorator",
|
":tf_decorator",
|
||||||
":tf_export",
|
":tf_export",
|
||||||
":tf_stack",
|
":tf_stack",
|
||||||
|
|
|
@ -18,10 +18,36 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python import _pywrap_tf32_execution
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
def tensor_float32_execution_allowed():
|
||||||
|
"""Return whether TensorFloat-32 execution is allowed on supported hardware.
|
||||||
|
|
||||||
|
This is a thin wrapper around `_pywrap_tf32_execution.is_allowed()`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if TensorFloat-32 execution is enabled and False otherwise.
|
||||||
|
"""
|
||||||
|
return _pywrap_tf32_execution.is_allowed()
|
||||||
|
|
||||||
|
def allow_tensor_float_32_execution(allow):
|
||||||
|
"""Allow or disallow TensorFloat-32 for float32 ops on supported hardware.
|
||||||
|
|
||||||
|
TensorFloat-32 is a math mode introduced with the NVIDIA Ampere architecture.
|
||||||
|
TensorFloat-32 kernels take float32 inputs and produce float32 outputs.
|
||||||
|
Internally, the inputs are cast to a custom representation with 10-bit
|
||||||
|
mantissa (similar to float16) and 8-bit exponent (similar to float32) and are
|
||||||
|
executed using TensorCores with float32 accumulation. For more information,
|
||||||
|
see https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/.
|
||||||
|
|
||||||
|
TensorFloat-32 execution is disabled by default, but this may change in a
|
||||||
|
future version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allow: Whether to allow TensorFloat-32 execution. The value is passed
|
||||||
|
unchanged to `_pywrap_tf32_execution.allow`.
|
||||||
|
"""
|
||||||
|
_pywrap_tf32_execution.allow(allow)
|
||||||
|
|
||||||
@tf_export('config.threading.get_intra_op_parallelism_threads')
|
@tf_export('config.threading.get_intra_op_parallelism_threads')
|
||||||
def get_intra_op_parallelism_threads():
|
def get_intra_op_parallelism_threads():
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/core/platform/tf32_utils.h"
|
||||||
|
|
||||||
|
// Python module `_pywrap_tf32_execution`: exposes
|
||||||
|
// tensorflow::allow_tf32_execution as `allow` and
|
||||||
|
// tensorflow::tf32_execution_allowed as `is_allowed`.
|
||||||
|
PYBIND11_MODULE(_pywrap_tf32_execution, m) {
|
||||||
|
m.def("allow", &tensorflow::allow_tf32_execution);
|
||||||
|
m.def("is_allowed", &tensorflow::tf32_execution_allowed);
|
||||||
|
}
|
Loading…
Reference in New Issue