Add ability to detect if compiled with nvcc
PiperOrigin-RevId: 277807909 Change-Id: I26e973356fe3f839b71e6cdafd93f900f874b078
This commit is contained in:
parent
c92be0fe5e
commit
161245b5a5
tensorflow
@ -34,6 +34,14 @@ bool IsBuiltWithROCm() {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool IsBuiltWithNvcc() {
|
||||
#if TENSORFLOW_USE_NVCC
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool GpuSupportsHalfMatMulAndConv() {
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||
|
@ -24,6 +24,9 @@ bool IsGoogleCudaEnabled();
|
||||
// Returns true if TENSORFLOW_USE_ROCM is defined. (i.e. TF is built with ROCm)
|
||||
bool IsBuiltWithROCm();
|
||||
|
||||
// Returns true if TENSORFLOW_USE_NVCC is defined. (i.e. TF is built with nvcc)
|
||||
bool IsBuiltWithNvcc();
|
||||
|
||||
// Returns true if either
|
||||
//
|
||||
// GOOGLE_CUDA is defined, and the given CUDA version supports
|
||||
|
@ -287,6 +287,10 @@ def IsBuiltWithROCm():
|
||||
return _pywrap_util_port.IsBuiltWithROCm()
|
||||
|
||||
|
||||
def IsBuiltWithNvcc():
|
||||
return _pywrap_util_port.IsBuiltWithNvcc()
|
||||
|
||||
|
||||
def GpuSupportsHalfMatMulAndConv():
|
||||
return _pywrap_util_port.GpuSupportsHalfMatMulAndConv()
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
PYBIND11_MODULE(_pywrap_util_port, m) {
|
||||
m.def("IsGoogleCudaEnabled", tensorflow::IsGoogleCudaEnabled);
|
||||
m.def("IsBuiltWithROCm", tensorflow::IsBuiltWithROCm);
|
||||
m.def("IsBuiltWithNvcc", tensorflow::IsBuiltWithNvcc);
|
||||
m.def("GpuSupportsHalfMatMulAndConv",
|
||||
tensorflow::GpuSupportsHalfMatMulAndConv);
|
||||
m.def("IsMklEnabled", tensorflow::IsMklEnabled);
|
||||
|
@ -69,6 +69,12 @@ def if_not_v2(a):
|
||||
"//conditions:default": a,
|
||||
})
|
||||
|
||||
def if_nvcc(a):
|
||||
return select({
|
||||
"@local_config_cuda//cuda:using_nvcc": a,
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def if_cuda_is_configured_compat(x):
|
||||
return if_cuda_is_configured(x)
|
||||
|
||||
@ -287,6 +293,7 @@ def tf_copts(
|
||||
]) +
|
||||
(if_not_windows(["-fno-exceptions"]) if not allow_exceptions else []) +
|
||||
if_cuda(["-DGOOGLE_CUDA=1"]) +
|
||||
if_nvcc(["-DTENSORFLOW_USE_NVCC=1"]) +
|
||||
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
|
||||
if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
|
||||
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
|
||||
|
@ -21,6 +21,7 @@ tensorflow::swig::RegisterType
|
||||
[util_port] # util_port
|
||||
tensorflow::IsGoogleCudaEnabled
|
||||
tensorflow::IsBuiltWithROCm
|
||||
tensorflow::IsBuiltWithNvcc
|
||||
tensorflow::GpuSupportsHalfMatMulAndConv
|
||||
tensorflow::IsMklEnabled
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user