Merge pull request from Intel-tensorflow:mazhar/auto_mixed_preci_bdw

PiperOrigin-RevId: 349277237
Change-Id: I2019e57e3280ff48886e63868aba62407c2676db
This commit is contained in:
TensorFlower Gardener 2020-12-28 06:46:19 -08:00
commit aa57d701e0
3 changed files with 26 additions and 0 deletions

View File

@ -50,6 +50,7 @@ from tensorflow.python.platform import sysconfig
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import _pywrap_utils
def _input(shape):
@ -371,6 +372,10 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
self.skipTest('No GPU is available')
if mode == 'mkl' and not test_util.IsMklEnabled():
self.skipTest('MKL is not enabled')
# Test will fail on machines without AVX512f, e.g., Broadwell
isAVX512f = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
if mode == 'mkl' and not isAVX512f:
self.skipTest('Skipping test due to non-AVX512f machine')
def _run_simple_loop_test(self, mode, inp, body, out):
"""Runs a test of a simple loop.

View File

@ -73,6 +73,7 @@ tf_python_pybind_extension(
hdrs = ["util.h"],
module_name = "_pywrap_utils",
deps = [
"//tensorflow/core/platform:platform_port",
"//tensorflow/python:pybind11_lib",
"//third_party/python_runtime:headers",
"@pybind11",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/util/util.h"
@ -348,4 +349,23 @@ PYBIND11_MODULE(_pywrap_utils, m) {
Returns:
True if `instance` is a `Variable`.
)pbdoc");
m.def(
"IsBF16SupportedByOneDNNOnThisCPU",
[]() {
bool result = tensorflow::port::TestCPUFeature(
tensorflow::port::CPUFeature::AVX512F);
if (PyErr_Occurred()) {
throw py::error_already_set();
}
return result;
},
R"pbdoc(
Returns 1 if CPU has avx512f feature.
Args:
None
Returns:
True if CPU has avx512f feature.
)pbdoc");
}